Skip to content

Commit

Permalink
feat: add ZBytes.serialize
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Sep 9, 2024
1 parent 528fdcd commit a97c932
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 57 deletions.
3 changes: 2 additions & 1 deletion examples/z_pub_thr.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def main():
data = bytearray()
for i in range(0, size):
data.append(i % 10)
data = zenoh.ZBytes(bytes(data))
congestion_control = zenoh.CongestionControl.BLOCK

with zenoh.open(conf) as session:
Expand All @@ -89,7 +90,7 @@ def main():

print("Press CTRL-C to quit...")
while True:
pub.put(bytes(data))
pub.put(data)


if __name__ == "__main__":
Expand Down
101 changes: 58 additions & 43 deletions src/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,45 +93,8 @@ pub(crate) fn serializer(
}
}

#[pyfunction]
#[pyo3(signature = (func = None, /, *, target = None))]
pub(crate) fn deserializer(
py: Python,
func: Option<&Bound<PyAny>>,
target: Option<&Bound<PyAny>>,
) -> PyResult<PyObject> {
match (func, target) {
(Some(func), Some(target)) => {
deserializers(py).bind(py).set_item(target, func)?;
Ok(py.None())
}
(Some(func), None) => match get_type(func, "return") {
Ok(target) => deserializer(py, Some(func), Some(&target)),
_ => Err(PyValueError::new_err(
"Cannot extract target from deserializer signature",
)),
},
(None, Some(target)) => {
let target = target.clone().unbind();
let closure = PyCFunction::new_closure_bound(py, None, None, move |args, _| {
let Ok(func) = args.get_item(0) else {
return Err(PyTypeError::new_err("expected one positional argument"));
};
deserializer(args.py(), Some(&func), Some(target.bind(args.py())))
})?;
Ok(closure.into_any().unbind())
}
(None, None) => Err(PyTypeError::new_err("missing 'func' or 'target' parameter")),
}
}

wrapper!(zenoh::bytes::ZBytes: Clone, Default);
downcast_or_new!(ZBytes);

#[pymethods]
impl ZBytes {
#[new]
fn new(obj: Option<&Bound<PyAny>>) -> PyResult<Self> {
fn serialize_impl(obj: Option<&Bound<PyAny>>) -> PyResult<Self> {
let Some(obj) = obj else {
return Ok(Self::default());
};
Expand All @@ -152,13 +115,17 @@ impl ZBytes {
} else if let Ok(list) = obj.downcast::<PyList>() {
try_process(
list.iter()
.map(|elt| PyResult::Ok(Self::new(Some(&elt))?.0)),
.map(|elt| PyResult::Ok(Self::serialize_impl(Some(&elt))?.0)),
|iter| iter.collect(),
)?
} else if let Ok(dict) = obj.downcast::<PyDict>() {
try_process(
dict.iter()
.map(|(k, v)| PyResult::Ok((Self::new(Some(&k))?.0, Self::new(Some(&v))?.0))),
dict.iter().map(|(k, v)| {
PyResult::Ok((
Self::serialize_impl(Some(&k))?.0,
Self::serialize_impl(Some(&v))?.0,
))
}),
|iter| iter.collect(),
)?
} else if let Ok(tuple) = obj.downcast::<PyTuple>() {
Expand All @@ -168,8 +135,8 @@ impl ZBytes {
));
}
zenoh::bytes::ZBytes::serialize((
Self::new(Some(&tuple.get_item(0)?))?,
Self::new(Some(&tuple.get_item(1)?))?,
Self::serialize_impl(Some(&tuple.get_item(0)?))?,
Self::serialize_impl(Some(&tuple.get_item(1)?))?,
))
} else if let Ok(Some(ser)) = serializers(py).bind(py).get_item(obj.get_type()) {
return match ZBytes::extract_bound(&ser.call1((obj,))?) {
Expand All @@ -185,6 +152,54 @@ impl ZBytes {
));
}))
}
}

#[pyfunction]
#[pyo3(signature = (func = None, /, *, target = None))]
pub(crate) fn deserializer(
py: Python,
func: Option<&Bound<PyAny>>,
target: Option<&Bound<PyAny>>,
) -> PyResult<PyObject> {
match (func, target) {
(Some(func), Some(target)) => {
deserializers(py).bind(py).set_item(target, func)?;
Ok(py.None())
}
(Some(func), None) => match get_type(func, "return") {
Ok(target) => deserializer(py, Some(func), Some(&target)),
_ => Err(PyValueError::new_err(
"Cannot extract target from deserializer signature",
)),
},
(None, Some(target)) => {
let target = target.clone().unbind();
let closure = PyCFunction::new_closure_bound(py, None, None, move |args, _| {
let Ok(func) = args.get_item(0) else {
return Err(PyTypeError::new_err("expected one positional argument"));
};
deserializer(args.py(), Some(&func), Some(target.bind(args.py())))
})?;
Ok(closure.into_any().unbind())
}
(None, None) => Err(PyTypeError::new_err("missing 'func' or 'target' parameter")),
}
}

wrapper!(zenoh::bytes::ZBytes: Clone, Default);
downcast_or_new!(serialize_impl: ZBytes);

#[pymethods]
impl ZBytes {
#[new]
fn new(bytes: Option<&Bound<PyBytes>>) -> Self {
bytes.map_or_else(Self::default, |b| Self(b.as_bytes().into()))
}

#[classmethod]
fn serialize(_cls: &Bound<PyType>, obj: Option<&Bound<PyAny>>) -> PyResult<Self> {
Self::serialize_impl(obj)
}

fn deserialize(this: PyRef<Self>, tp: &Bound<PyAny>) -> PyResult<PyObject> {
let py = tp.py();
Expand Down
7 changes: 5 additions & 2 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ macro_rules! bail {
pub(crate) use bail;

macro_rules! downcast_or_new {
($ty:ty $(=> $new:ty)? $(, $other:expr)?) => {
($method:ident: $ty:ty $(=> $new:ty)? $(, $other:expr)?) => {
#[allow(unused)]
impl $ty {
pub(crate) fn from_py(obj: &Bound<PyAny>) -> PyResult<Self> {
if let Ok(obj) = <Self as pyo3::FromPyObject>::extract_bound(obj) {
return Ok(obj);
}
Self::new(PyResult::Ok(obj)$(.and_then(<$new>::extract_bound))??.into(), $($other)?)
Self::$method(PyResult::Ok(obj)$(.and_then(<$new>::extract_bound))??.into(), $($other)?)
}
pub(crate) fn from_py_opt(obj: &Bound<PyAny>) -> PyResult<Option<Self>> {
if obj.is_none() {
Expand All @@ -79,6 +79,9 @@ macro_rules! downcast_or_new {
}
}
};
($ty:ty $(=> $new:ty)? $(, $other:expr)?) => {
$crate::macros::downcast_or_new!(new: $ty $(=> $new)? $(, $other)?);
};
}
pub(crate) use downcast_or_new;

Expand Down
16 changes: 8 additions & 8 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
(int, 42),
(float, 0.5),
(bool, True),
(ZBytes, ZBytes(b"foo")),
(list, [ZBytes(0), ZBytes(1)]),
(dict, {ZBytes("foo"): ZBytes("bar")}),
(ZBytes, ZBytes.serialize(b"foo")),
(list, [ZBytes.serialize(0), ZBytes.serialize(1)]),
(dict, {ZBytes.serialize("foo"): ZBytes.serialize("bar")}),
]
if sys.version_info >= (3, 9):
default_serializer_tests = [
Expand All @@ -40,7 +40,7 @@

@pytest.mark.parametrize("tp, value", default_serializer_tests)
def test_default_serializer(tp, value):
assert ZBytes(value).deserialize(tp) == value
assert ZBytes.serialize(value).deserialize(tp) == value


def test_registered_serializer():
Expand All @@ -54,10 +54,10 @@ def deserialize_foo(zbytes: ZBytes) -> Foo:

@serializer
def serialize_foo(foo: Foo) -> ZBytes:
return ZBytes(foo.bar)
return ZBytes.serialize(foo.bar)

foo = Foo(42)
assert ZBytes(foo).deserialize(Foo) == foo
assert ZBytes.serialize(foo).deserialize(Foo) == foo


def test_registered_serializer_with_target():
Expand All @@ -71,7 +71,7 @@ def deserialize_foo(zbytes):

@serializer(target=Foo)
def serialize_foo(foo):
return ZBytes(foo.bar)
return ZBytes.serialize(foo.bar)

foo = Foo(42)
assert ZBytes(foo).deserialize(Foo) == foo
assert ZBytes.serialize(foo).deserialize(Foo) == foo
11 changes: 8 additions & 3 deletions zenoh/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -962,10 +962,15 @@ _IntoWhatAmIMatcher = WhatAmIMatcher | str
class ZBytes:
"""ZBytes contains the serialized bytes of user data."""

def __new__(cls, obj: Any = None) -> Self: ...
def __new__(cls, bytes: bytes = None) -> Self: ...
@classmethod
def serialize(cls, obj: Any) -> Self:
"""Serialize object according to its type,
using default or registered serializer."""

def deserialize(self, tp: type[_T]) -> _T:
"""Deserialize payload to the given types,
using default or registered deserializer"""
"""Deserialize bytes to the given types,
using default or registered deserializer."""

def __bool__(self) -> bool: ...
def __len__(self) -> int: ...
Expand Down

0 comments on commit a97c932

Please sign in to comment.