Skip to content

Commit

Permalink
Implement util module with socket helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
miccoli committed Aug 15, 2023
1 parent 17f2220 commit d35f5b3
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/trick17/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: 2023-present Stefano Miccoli <[email protected]>
#
# SPDX-License-Identifier: MIT

"""internal utility functions"""

import array
import errno
import fcntl
import os
import socket


def make_socket() -> socket.socket:
"""return a SOCK_DGRAM socket for communication with systemd"""
return socket.socket(family=socket.AF_UNIX, type=socket.SOCK_DGRAM)


def send_dgram_or_fd(sock: socket.socket, payload: bytes, address: str) -> None:
"""implement systemd logic: first try to send payload as a datagram,
if failed, retry sending as a mem_fd.
"""
retry_fd: bool
try:
nsent = sock.sendto(payload, address)
assert nsent == len(payload), f"Boundary broken? {nsent} != {len(payload)}"
retry_fd = False
except OSError as err:
if err.errno == errno.EMSGSIZE:
retry_fd = True
else:
raise
if retry_fd:
# send big payload as a memfd
fd = os.memfd_create(
"journal_entry", flags=os.MFD_CLOEXEC | os.MFD_ALLOW_SEALING
)
nwr = os.write(fd, payload)
assert nwr == len(
payload
), f"Unable to write to memfd: {nwr} != {len(payload)}"
# see https://github.com/systemd/systemd/issues/27608
fcntl.fcntl(
fd,
fcntl.F_ADD_SEALS,
fcntl.F_SEAL_SHRINK
| fcntl.F_SEAL_GROW
| fcntl.F_SEAL_WRITE
| fcntl.F_SEAL_SEAL,
)
_send_fds(sock=sock, buffers=[], fds=[fd], address=address)


def _send_fds(sock, buffers, fds, flags=0, address=None):
"""send_fds(sock, buffers, fds[, flags[, address]]) -> integer
Send the list of file descriptors fds over an AF_UNIX socket.
*** Patch to fix cpython bug GH-107898 ***
"""
return sock.sendmsg(
buffers,
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))],
flags,
address,
)
39 changes: 39 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import mmap
import random
import socket

import pytest

from trick17 import util


def test_make_socket():
with util.make_socket() as sock:
assert sock.fileno() != -1


@pytest.mark.skipif(
not hasattr(socket.socket, "sendmsg"),
reason="platform not supporting sendmsg",
)
def test_send(tmp_path):
sock_path = str(tmp_path / "socket")
with util.make_socket() as a, util.make_socket() as b:
a.bind(sock_path)

for nsend in (1, 2**17, 2**18):
out = random.getrandbits(nsend * 8).to_bytes(nsend, "little")

util.send_dgram_or_fd(b, out, sock_path)

msg, fds, msg_flags, _ = socket.recv_fds(a, len(out), 1)
assert (
msg_flags == 0
), f"Expecting 0, got msg_flags {socket.MsgFlag(msg_flags).name}"
if msg:
assert msg == out
assert len(fds) == 0
else:
assert len(fds) == 1
mm = mmap.mmap(fds[0], 0, flags=mmap.MAP_PRIVATE)
assert mm[:] == out

0 comments on commit d35f5b3

Please sign in to comment.