Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for making connections over unix domain sockets #1620

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ nav_order: 2
- Fix typo in management procedure (`nm_invididual_address_write` was renamed to `nm_individual_address_write`)
- Fix TunnellingFeatureResponse missing `return_code`

### Connection

- Add support for making connections over unix domain sockets.

# 3.3.0 Climate humidity 2024-10-20

### Devices
Expand Down
86 changes: 86 additions & 0 deletions examples/example_telegram_monitor_unix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Example for the telegram monitor callback over unix domain socket."""

import asyncio
import getopt
import socket
import sys

from xknx import XKNX
from xknx.io import ConnectionConfig, ConnectionType
from xknx.telegram import AddressFilter, Telegram


def telegram_received_cb(telegram: Telegram) -> None:
"""Do something with the received telegram."""
print(f"Telegram received: {telegram}")


def show_help() -> None:
"""Print Help."""
print("Telegram filter.")
print("")
print("Usage:")
print("")
print(__file__, " Listen to all telegrams")
print(
__file__, "-f --filter 1/2/*,1/4/[5-6] Filter for specific group addresses"
)
print(
__file__, "-host hostname Connect to a specific host over ssh"
)
print(__file__, "-h --help Print help")
print("")


async def monitor(host, address_filters: list[AddressFilter] | None) -> None:
"""Set telegram_received_cb within XKNX and connect to KNX/IP device in daemon mode."""
if host is None:
connection_config = ConnectionConfig(
connection_type=ConnectionType.TUNNELING_TCP,
gateway_path="/run/knxnet",
)
else:

async def connect_ssh(loop, protocol_factory):
s1, s2 = socket.socketpair()

cmd = ["ssh", "--", host, "socat STDIO UNIX-CONNECT:/run/knxnet"]

await asyncio.create_subprocess_exec(*cmd, stdin=s2, stdout=s2)

return await loop.create_unix_connection(protocol_factory, sock=s1)

connection_config = ConnectionConfig(
connection_type=ConnectionType.TUNNELING_TCP,
connect_cb=connect_ssh,
)
xknx = XKNX(connection_config=connection_config, daemon_mode=True)
xknx.telegram_queue.register_telegram_received_cb(
telegram_received_cb, address_filters
)
await xknx.start()
await xknx.stop()


async def main(argv: list[str]) -> None:
"""Parse command line arguments and start monitor."""
try:
opts, _ = getopt.getopt(argv, "hf:", ["help", "filter=", "host="])
except getopt.GetoptError:
show_help()
sys.exit(2)
host = None
address_filters = None
for opt, arg in opts:
if opt in ["-h", "--help"]:
show_help()
sys.exit()
if opt in ["--host"]:
host = arg
if opt in ["-f", "--filter"]:
address_filters = list(map(AddressFilter, arg.split(",")))
await monitor(host, address_filters)


if __name__ == "__main__":
asyncio.run(main(sys.argv[1:]))
8 changes: 8 additions & 0 deletions test/io_tests/knxip_interface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ async def gateway_generator_mock(_):
start_tunnelling_tcp_mock.assert_called_once_with(
gateway_ip="10.1.0.0",
gateway_port=3671,
gateway_path=None,
connect_cb=None,
)

async def test_start_automatic_with_keyring_and_ia(self):
Expand Down Expand Up @@ -170,6 +172,8 @@ async def gateway_generator_mock(_):
start_tunnelling_tcp_mock.assert_called_once_with(
gateway_ip="10.1.0.0",
gateway_port=3671,
gateway_path=None,
connect_cb=None,
)

# IA not listed in keyring
Expand Down Expand Up @@ -240,6 +244,8 @@ async def test_start_tcp_tunnel_connection(self):
start_tunnelling_tcp.assert_called_once_with(
gateway_ip=gateway_ip,
gateway_port=3671,
gateway_path=None,
connect_cb=None,
)
with patch("xknx.io.tunnel.TCPTunnel.connect") as connect_tcp:
interface = knx_interface_factory(self.xknx, connection_config)
Expand Down Expand Up @@ -271,6 +277,8 @@ async def test_start_tcp_tunnel_connection_with_ia(self):
start_tunnelling_tcp.assert_called_once_with(
gateway_ip=gateway_ip,
gateway_port=3671,
gateway_path=None,
connect_cb=None,
)
with patch("xknx.io.tunnel.TCPTunnel.connect") as connect_tcp:
interface = knx_interface_factory(self.xknx, connection_config)
Expand Down
12 changes: 12 additions & 0 deletions xknx/io/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from enum import Enum, auto
import os
from typing import Any
Expand Down Expand Up @@ -41,6 +43,8 @@ class ConnectionConfig:
* local_ip: Local ip or interface name though which xknx should connect.
* gateway_ip: IP or hostname of KNX/IP tunneling device.
* gateway_port: Port of KNX/IP tunneling device.
* gateway_path: Filename of unix domain socket of KNX/IP tunneling device.
* connect_cb: A callback which will be called every time a connection is created.
* route_back: For UDP TUNNELING connection.
The KNXnet/IP Server shall use the IP address and port in the received IP package
as the target IP address or port number for the response to the KNXnet/IP Client.
Expand All @@ -62,6 +66,12 @@ def __init__(
local_port: int = 0,
gateway_ip: str | None = None,
gateway_port: int = DEFAULT_MCAST_PORT,
gateway_path: str | None = None,
connect_cb: Callable[
[asyncio.AbstractEventLoop, Callable[[], asyncio.Protocol]],
Awaitable[tuple[asyncio.Transport, asyncio.Protocol]],
]
| None = None,
route_back: bool = False,
multicast_group: str = DEFAULT_MCAST_GRP,
multicast_port: int = DEFAULT_MCAST_PORT,
Expand All @@ -80,6 +90,8 @@ def __init__(
self.local_port = local_port
self.gateway_ip = gateway_ip
self.gateway_port = gateway_port
self.gateway_path = gateway_path
self.connect_cb = connect_cb
self.route_back = route_back
self.multicast_group = multicast_group
self.multicast_port = multicast_port
Expand Down
33 changes: 27 additions & 6 deletions xknx/io/knxip_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Awaitable
from collections.abc import Awaitable, Callable
import logging
import threading
from typing import TYPE_CHECKING, TypeVar
Expand Down Expand Up @@ -105,11 +105,17 @@
)
elif (
self.connection_config.connection_type == ConnectionType.TUNNELING_TCP
and gateway_ip is not None
and (
gateway_ip is not None
or self.connection_config.gateway_path is not None
or self.connection_config.connect_cb is not None
)
):
await self._start_tunnelling_tcp(
gateway_ip=gateway_ip,
gateway_port=self.connection_config.gateway_port,
gateway_path=self.connection_config.gateway_path,
connect_cb=self.connection_config.connect_cb,
)
elif (
self.connection_config.connection_type
Expand Down Expand Up @@ -172,6 +178,8 @@
await self._start_tunnelling_tcp(
gateway_ip=gateway.ip_addr,
gateway_port=gateway.port,
gateway_path=None,
connect_cb=None,
)
elif (
gateway.supports_tunnelling
Expand Down Expand Up @@ -202,16 +210,27 @@

async def _start_tunnelling_tcp(
self,
gateway_ip: str,
gateway_ip: str | None,
gateway_port: int,
gateway_path: str | None,
connect_cb: Callable[
[asyncio.AbstractEventLoop, Callable[[], asyncio.Protocol]],
Awaitable[tuple[asyncio.Transport, asyncio.Protocol]],
]
| None,
) -> None:
"""Start KNX/IP TCP tunnel."""
tunnel_address = self.connection_config.individual_address

if connect_cb is not None:
connect_info = "using connect callback"

Check warning on line 226 in xknx/io/knxip_interface.py

View check run for this annotation

Codecov / codecov/patch

xknx/io/knxip_interface.py#L226

Added line #L226 was not covered by tests
elif gateway_path is not None:
connect_info = f"Unix Domain Socket {gateway_path}"

Check warning on line 228 in xknx/io/knxip_interface.py

View check run for this annotation

Codecov / codecov/patch

xknx/io/knxip_interface.py#L228

Added line #L228 was not covered by tests
else:
connect_info = f"{gateway_ip}:{gateway_port} over TCP"
logger.debug(
"Starting tunnel to %s:%s over TCP%s",
gateway_ip,
gateway_port,
"Starting tunnel to %s%s",
connect_info,
f" requesting individual address {tunnel_address}"
if tunnel_address
else "",
Expand All @@ -220,6 +239,8 @@
self.xknx,
gateway_ip=gateway_ip,
gateway_port=gateway_port,
gateway_path=gateway_path,
connect_cb=connect_cb,
individual_address=tunnel_address,
cemi_received_callback=self.cemi_received,
auto_reconnect=self.connection_config.auto_reconnect,
Expand Down
24 changes: 18 additions & 6 deletions xknx/io/transport/tcp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
import logging

from xknx.exceptions import CommunicationError, CouldNotParseKNXIP, IncompleteKNXIPFrame
Expand Down Expand Up @@ -67,13 +67,19 @@
self,
remote_addr: tuple[str, int],
connection_lost_cb: Callable[[], None] | None = None,
connect_cb: Callable[
[asyncio.AbstractEventLoop, Callable[[], asyncio.Protocol]],
Awaitable[tuple[asyncio.Transport, asyncio.Protocol]],
]
| None = None,
):
"""Initialize TCPTransport class."""
self.remote_addr = remote_addr
self.remote_hpai = HPAI(*remote_addr, protocol=HostProtocol.IPV4_TCP)

self.callbacks = []
self._connection_lost_cb = connection_lost_cb
self._connect_cb = connect_cb
self.transport: asyncio.Transport | None = None
self._buffer = b""

Expand Down Expand Up @@ -117,11 +123,17 @@
connection_lost_callback=self._connection_lost,
)
loop = asyncio.get_running_loop()
(self.transport, _) = await loop.create_connection(
lambda: tcp_transport_factory,
host=self.remote_hpai.ip_addr,
port=self.remote_hpai.port,
)
if self._connect_cb is None:
(self.transport, _) = await loop.create_connection(

Check warning on line 127 in xknx/io/transport/tcp_transport.py

View check run for this annotation

Codecov / codecov/patch

xknx/io/transport/tcp_transport.py#L126-L127

Added lines #L126 - L127 were not covered by tests
lambda: tcp_transport_factory,
host=self.remote_hpai.ip_addr,
port=self.remote_hpai.port,
)
else:
(self.transport, _) = await self._connect_cb(

Check warning on line 133 in xknx/io/transport/tcp_transport.py

View check run for this annotation

Codecov / codecov/patch

xknx/io/transport/tcp_transport.py#L133

Added line #L133 was not covered by tests
loop,
lambda: tcp_transport_factory,
)

def _connection_lost(self) -> None:
"""Call assigned callback. Callback for connection lost."""
Expand Down
Loading
Loading