Skip to content

Commit

Permalink
refactor bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
s0l0ist committed Oct 2, 2024
1 parent e9f4058 commit 4da5d46
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
os: [ubuntu-24.04, ubuntu-22.04, ubuntu-20.04, macos-14]
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ jobs:
os:
[ubuntu-24.04, macos-14]
# Bazel uses hermetic python, these are just placeholders
python-version: ['3_8', '3_9', '3_10', '3_11']
python-version: ['3_8', '3_9', '3_10', '3_11', '3_12']
steps:
- uses: actions/checkout@v4
# configuring python for bazel abi and platform repo rules
Expand Down
6 changes: 6 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Version 2.0.5

Feat:

- Add support for python 3.12

# Version 2.0.4

Chore:
Expand Down
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module(
name = "org_openmined_psi",
version = "2.0.4",
version = "2.0.5",
)

http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
Expand Down
6 changes: 3 additions & 3 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@openmined/psi.js",
"version": "2.0.4",
"version": "2.0.5",
"description": "Private Set Intersection for JavaScript",
"repository": {
"type": "git",
Expand Down
167 changes: 107 additions & 60 deletions private_set_intersection/python/psi_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ namespace py = pybind11;

template <class T>
T throwOrReturn(const absl::StatusOr<T>& in) {
if (!in.ok()) throw std::runtime_error(std::string(in.status().message()));
if (!in.ok()) {
py::gil_scoped_acquire acquire;
throw std::runtime_error(std::string(in.status().message()));
}
return *in;
}

Expand All @@ -29,6 +32,7 @@ auto saveProto(const T& obj) {
template <class T>
auto loadProto(T& obj, const py::bytes& data) {
if (!obj.ParseFromString(data)) {
py::gil_scoped_acquire acquire;
throw std::invalid_argument("failed to parse proto data");
}
}
Expand All @@ -47,74 +51,107 @@ void bind(pybind11::module& m) {

py::class_<psi_proto::ServerSetup>(m, "cpp_proto_server_setup")
.def(py::init<>())
.def("load", [](psi_proto::ServerSetup& obj,
const py::bytes& data) { return loadProto(obj, data); })
.def(
"load",
[](psi_proto::ServerSetup& obj, const py::bytes& data) {
return loadProto(obj, data);
},
py::call_guard<py::gil_scoped_release>())
.def("save",
[](const psi_proto::ServerSetup& obj) { return saveProto(obj); })
.def_static("Load", [](const py::bytes& data) {
psi_proto::ServerSetup obj;
loadProto(obj, data);
return obj;
});
.def_static(
"Load",
[](const py::bytes& data) {
psi_proto::ServerSetup obj;
loadProto(obj, data);
return obj;
},
py::call_guard<py::gil_scoped_release>());
py::class_<psi_proto::Request>(m, "cpp_proto_request")
.def(py::init<>())
.def("load", [](psi_proto::Request& obj,
const py::bytes& data) { return loadProto(obj, data); })
.def(
"load",
[](psi_proto::Request& obj, const py::bytes& data) {
return loadProto(obj, data);
},
py::call_guard<py::gil_scoped_release>())
.def("save", [](const psi_proto::Request& obj) { return saveProto(obj); })
.def_static("Load", [](const py::bytes& data) {
psi_proto::Request obj;
loadProto(obj, data);
return obj;
});
.def_static(
"Load",
[](const py::bytes& data) {
psi_proto::Request obj;
loadProto(obj, data);
return obj;
},
py::call_guard<py::gil_scoped_release>());
py::class_<psi_proto::Response>(m, "cpp_proto_response")
.def(py::init<>())
.def("load", [](psi_proto::Response& obj,
const py::bytes& data) { return loadProto(obj, data); })
.def(
"load",
[](psi_proto::Response& obj, const py::bytes& data) {
return loadProto(obj, data);
},
py::call_guard<py::gil_scoped_release>())
.def("save",
[](const psi_proto::Response& obj) { return saveProto(obj); })
.def_static("Load", [](const py::bytes& data) {
psi_proto::Response obj;
loadProto(obj, data);
return obj;
});
.def_static(
"Load",
[](const py::bytes& data) {
psi_proto::Response obj;
loadProto(obj, data);
return obj;
},
py::call_guard<py::gil_scoped_release>());

py::class_<psi::PsiClient>(m, "cpp_client")
.def_static(
"CreateWithNewKey",
[](bool reveal_intersection) {
auto client = psi::PsiClient::CreateWithNewKey(reveal_intersection);
if (!client.ok())
if (!client.ok()) {
py::gil_scoped_acquire acquire;
throw std::runtime_error(std::string(client.status().message()));
}
return std::move(*client);
})
},
py::call_guard<py::gil_scoped_release>())
.def_static(
"CreateFromKey",
[](const std::string& key_bytes, bool reveal_intersection) {
auto client =
psi::PsiClient::CreateFromKey(key_bytes, reveal_intersection);
if (!client.ok())
if (!client.ok()) {
py::gil_scoped_acquire acquire;
throw std::runtime_error(std::string(client.status().message()));
}
return std::move(*client);
})
.def("CreateRequest",
[](const psi::PsiClient& obj,
const std::vector<std::string>& inputs) {
return throwOrReturn(obj.CreateRequest(absl::MakeSpan(inputs)));
})
.def("GetIntersection",
[](const psi::PsiClient& obj,
const psi_proto::ServerSetup& server_setup,
const psi_proto::Response& server_response) {
return throwOrReturn(
obj.GetIntersection(server_setup, server_response));
})
.def("GetIntersectionSize",
[](const psi::PsiClient& obj,
const psi_proto::ServerSetup& server_setup,
const psi_proto::Response& server_response) {
return throwOrReturn(
obj.GetIntersectionSize(server_setup, server_response));
})
},
py::call_guard<py::gil_scoped_release>())
.def(
"CreateRequest",
[](const psi::PsiClient& obj,
const std::vector<std::string>& inputs) {
return throwOrReturn(obj.CreateRequest(absl::MakeSpan(inputs)));
},
py::call_guard<py::gil_scoped_release>())
.def(
"GetIntersection",
[](const psi::PsiClient& obj,
const psi_proto::ServerSetup& server_setup,
const psi_proto::Response& server_response) {
return throwOrReturn(
obj.GetIntersection(server_setup, server_response));
},
py::call_guard<py::gil_scoped_release>())
.def(
"GetIntersectionSize",
[](const psi::PsiClient& obj,
const psi_proto::ServerSetup& server_setup,
const psi_proto::Response& server_response) {
return throwOrReturn(
obj.GetIntersectionSize(server_setup, server_response));
},
py::call_guard<py::gil_scoped_release>())
.def("GetPrivateKeyBytes", [](const psi::PsiClient& obj) {
return py::bytes(obj.GetPrivateKeyBytes());
});
Expand All @@ -124,30 +161,40 @@ void bind(pybind11::module& m) {
"CreateWithNewKey",
[](bool reveal_intersection) {
auto server = psi::PsiServer::CreateWithNewKey(reveal_intersection);
if (!server.ok())
if (!server.ok()) {
py::gil_scoped_acquire acquire;
throw std::runtime_error(std::string(server.status().message()));
}
return std::move(*server);
})
},
py::call_guard<py::gil_scoped_release>())
.def_static(
"CreateFromKey",
[](const std::string& key_bytes, bool reveal_intersection) {
auto server =
psi::PsiServer::CreateFromKey(key_bytes, reveal_intersection);
if (!server.ok())
if (!server.ok()) {
py::gil_scoped_acquire acquire;
throw std::runtime_error(std::string(server.status().message()));
}
return std::move(*server);
})
.def("CreateSetupMessage",
[](const psi::PsiServer& obj, double fpr, int64_t num_client_inputs,
const std::vector<std::string>& inputs, psi::DataStructure ds) {
return throwOrReturn(obj.CreateSetupMessage(
fpr, num_client_inputs, absl::MakeSpan(inputs), ds));
})
.def("ProcessRequest",
[](const psi::PsiServer& obj,
const psi_proto::Request& client_request) {
return throwOrReturn(obj.ProcessRequest(client_request));
})
},
py::call_guard<py::gil_scoped_release>())
.def(
"CreateSetupMessage",
[](const psi::PsiServer& obj, double fpr, int64_t num_client_inputs,
const std::vector<std::string>& inputs, psi::DataStructure ds) {
return throwOrReturn(obj.CreateSetupMessage(
fpr, num_client_inputs, absl::MakeSpan(inputs), ds));
},
py::call_guard<py::gil_scoped_release>())
.def(
"ProcessRequest",
[](const psi::PsiServer& obj,
const psi_proto::Request& client_request) {
return throwOrReturn(obj.ProcessRequest(client_request));
},
py::call_guard<py::gil_scoped_release>())
.def("GetPrivateKeyBytes", [](const psi::PsiServer& obj) {
return py::bytes(obj.GetPrivateKeyBytes());
});
Expand Down
2 changes: 1 addition & 1 deletion tools/package.bzl
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
""" Version of the current release """
VERSION_LABEL = "2.0.4"
VERSION_LABEL = "2.0.5"

0 comments on commit 4da5d46

Please sign in to comment.