Skip to content

Commit

Permalink
Add support for making connections over unix domain sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
steffen-kiess committed Dec 13, 2024
1 parent c9625a1 commit 5dba908
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 19 deletions.
86 changes: 86 additions & 0 deletions examples/example_telegram_monitor_unix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Example for the telegram monitor callback over unix domain socket."""

import asyncio
import getopt
import socket
import sys

from xknx import XKNX
from xknx.io import ConnectionConfig, ConnectionType
from xknx.telegram import AddressFilter, Telegram


def telegram_received_cb(telegram: Telegram) -> None:
"""Do something with the received telegram."""
print(f"Telegram received: {telegram}")


def show_help() -> None:
"""Print Help."""
print("Telegram filter.")
print("")
print("Usage:")
print("")
print(__file__, " Listen to all telegrams")
print(
__file__, "-f --filter 1/2/*,1/4/[5-6] Filter for specific group addresses"
)
print(
__file__, "-host hostname Connect to a specific host over ssh"
)
print(__file__, "-h --help Print help")
print("")


async def monitor(host, address_filters: list[AddressFilter] | None) -> None:
"""Set telegram_received_cb within XKNX and connect to KNX/IP device in daemon mode."""
if host is None:
connection_config = ConnectionConfig(
connection_type=ConnectionType.TUNNELING_TCP,
gateway_path="/run/knxnet",
)
else:

async def connect_ssh(loop, protocol_factory):
s1, s2 = socket.socketpair()

cmd = ["ssh", "--", host, "socat STDIO UNIX-CONNECT:/run/knxnet"]

await asyncio.create_subprocess_exec(*cmd, stdin=s2, stdout=s2)

return await loop.create_unix_connection(protocol_factory, sock=s1)

connection_config = ConnectionConfig(
connection_type=ConnectionType.TUNNELING_TCP,
connect_cb=connect_ssh,
)
xknx = XKNX(connection_config=connection_config, daemon_mode=True)
xknx.telegram_queue.register_telegram_received_cb(
telegram_received_cb, address_filters
)
await xknx.start()
await xknx.stop()


async def main(argv: list[str]) -> None:
"""Parse command line arguments and start monitor."""
try:
opts, _ = getopt.getopt(argv, "hf:", ["help", "filter=", "host="])
except getopt.GetoptError:
show_help()
sys.exit(2)
host = None
address_filters = None
for opt, arg in opts:
if opt in ["-h", "--help"]:
show_help()
sys.exit()
if opt in ["--host"]:
host = arg
if opt in ["-f", "--filter"]:
address_filters = list(map(AddressFilter, arg.split(",")))
await monitor(host, address_filters)


if __name__ == "__main__":
asyncio.run(main(sys.argv[1:]))
8 changes: 8 additions & 0 deletions test/io_tests/knxip_interface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ async def gateway_generator_mock(_):
start_tunnelling_tcp_mock.assert_called_once_with(
gateway_ip="10.1.0.0",
gateway_port=3671,
gateway_path=None,
connect_cb=None,
)

async def test_start_automatic_with_keyring_and_ia(self):
Expand Down Expand Up @@ -170,6 +172,8 @@ async def gateway_generator_mock(_):
start_tunnelling_tcp_mock.assert_called_once_with(
gateway_ip="10.1.0.0",
gateway_port=3671,
gateway_path=None,
connect_cb=None,
)

# IA not listed in keyring
Expand Down Expand Up @@ -240,6 +244,8 @@ async def test_start_tcp_tunnel_connection(self):
start_tunnelling_tcp.assert_called_once_with(
gateway_ip=gateway_ip,
gateway_port=3671,
gateway_path=None,
connect_cb=None,
)
with patch("xknx.io.tunnel.TCPTunnel.connect") as connect_tcp:
interface = knx_interface_factory(self.xknx, connection_config)
Expand Down Expand Up @@ -271,6 +277,8 @@ async def test_start_tcp_tunnel_connection_with_ia(self):
start_tunnelling_tcp.assert_called_once_with(
gateway_ip=gateway_ip,
gateway_port=3671,
gateway_path=None,
connect_cb=None,
)
with patch("xknx.io.tunnel.TCPTunnel.connect") as connect_tcp:
interface = knx_interface_factory(self.xknx, connection_config)
Expand Down
12 changes: 12 additions & 0 deletions xknx/io/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from enum import Enum, auto
import os
from typing import Any
Expand Down Expand Up @@ -41,6 +43,8 @@ class ConnectionConfig:
* local_ip: Local ip or interface name though which xknx should connect.
* gateway_ip: IP or hostname of KNX/IP tunneling device.
* gateway_port: Port of KNX/IP tunneling device.
* gateway_path: Filename of unix domain socket of KNX/IP tunneling device.
* connect_cb: A callback which will be called every time a connection is created.
* route_back: For UDP TUNNELING connection.
The KNXnet/IP Server shall use the IP address and port in the received IP package
as the target IP address or port number for the response to the KNXnet/IP Client.
Expand All @@ -62,6 +66,12 @@ def __init__(
local_port: int = 0,
gateway_ip: str | None = None,
gateway_port: int = DEFAULT_MCAST_PORT,
gateway_path: str | None = None,
connect_cb: Callable[
[asyncio.AbstractEventLoop, Callable[[], asyncio.Protocol]],
Awaitable[tuple[asyncio.Transport, asyncio.Protocol]],
]
| None = None,
route_back: bool = False,
multicast_group: str = DEFAULT_MCAST_GRP,
multicast_port: int = DEFAULT_MCAST_PORT,
Expand All @@ -80,6 +90,8 @@ def __init__(
self.local_port = local_port
self.gateway_ip = gateway_ip
self.gateway_port = gateway_port
self.gateway_path = gateway_path
self.connect_cb = connect_cb
self.route_back = route_back
self.multicast_group = multicast_group
self.multicast_port = multicast_port
Expand Down
33 changes: 27 additions & 6 deletions xknx/io/knxip_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Awaitable
from collections.abc import Awaitable, Callable
import logging
import threading
from typing import TYPE_CHECKING, TypeVar
Expand Down Expand Up @@ -105,11 +105,17 @@ async def _start(self) -> None:
)
elif (
self.connection_config.connection_type == ConnectionType.TUNNELING_TCP
and gateway_ip is not None
and (
gateway_ip is not None
or self.connection_config.gateway_path is not None
or self.connection_config.connect_cb is not None
)
):
await self._start_tunnelling_tcp(
gateway_ip=gateway_ip,
gateway_port=self.connection_config.gateway_port,
gateway_path=self.connection_config.gateway_path,
connect_cb=self.connection_config.connect_cb,
)
elif (
self.connection_config.connection_type
Expand Down Expand Up @@ -172,6 +178,8 @@ async def _start_automatic(
await self._start_tunnelling_tcp(
gateway_ip=gateway.ip_addr,
gateway_port=gateway.port,
gateway_path=None,
connect_cb=None,
)
elif (
gateway.supports_tunnelling
Expand Down Expand Up @@ -202,16 +210,27 @@ async def _start_automatic(

async def _start_tunnelling_tcp(
self,
gateway_ip: str,
gateway_ip: str | None,
gateway_port: int,
gateway_path: str | None,
connect_cb: Callable[
[asyncio.AbstractEventLoop, Callable[[], asyncio.Protocol]],
Awaitable[tuple[asyncio.Transport, asyncio.Protocol]],
]
| None,
) -> None:
"""Start KNX/IP TCP tunnel."""
tunnel_address = self.connection_config.individual_address

if connect_cb is not None:
connect_info = "using connect callback"
elif gateway_path is not None:
connect_info = f"Unix Domain Socket {gateway_path}"
else:
connect_info = f"{gateway_ip}:{gateway_port} over TCP"
logger.debug(
"Starting tunnel to %s:%s over TCP%s",
gateway_ip,
gateway_port,
"Starting tunnel to %s%s",
connect_info,
f" requesting individual address {tunnel_address}"
if tunnel_address
else "",
Expand All @@ -220,6 +239,8 @@ async def _start_tunnelling_tcp(
self.xknx,
gateway_ip=gateway_ip,
gateway_port=gateway_port,
gateway_path=gateway_path,
connect_cb=connect_cb,
individual_address=tunnel_address,
cemi_received_callback=self.cemi_received,
auto_reconnect=self.connection_config.auto_reconnect,
Expand Down
24 changes: 18 additions & 6 deletions xknx/io/transport/tcp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
import logging

from xknx.exceptions import CommunicationError, CouldNotParseKNXIP, IncompleteKNXIPFrame
Expand Down Expand Up @@ -67,13 +67,19 @@ def __init__(
self,
remote_addr: tuple[str, int],
connection_lost_cb: Callable[[], None] | None = None,
connect_cb: Callable[
[asyncio.AbstractEventLoop, Callable[[], asyncio.Protocol]],
Awaitable[tuple[asyncio.Transport, asyncio.Protocol]],
]
| None = None,
):
"""Initialize TCPTransport class."""
self.remote_addr = remote_addr
self.remote_hpai = HPAI(*remote_addr, protocol=HostProtocol.IPV4_TCP)

self.callbacks = []
self._connection_lost_cb = connection_lost_cb
self._connect_cb = connect_cb
self.transport: asyncio.Transport | None = None
self._buffer = b""

Expand Down Expand Up @@ -117,11 +123,17 @@ async def connect(self) -> None:
connection_lost_callback=self._connection_lost,
)
loop = asyncio.get_running_loop()
(self.transport, _) = await loop.create_connection(
lambda: tcp_transport_factory,
host=self.remote_hpai.ip_addr,
port=self.remote_hpai.port,
)
if self._connect_cb is None:
(self.transport, _) = await loop.create_connection(
lambda: tcp_transport_factory,
host=self.remote_hpai.ip_addr,
port=self.remote_hpai.port,
)
else:
(self.transport, _) = await self._connect_cb(
loop,
lambda: tcp_transport_factory,
)

def _connection_lost(self) -> None:
"""Call assigned callback. Callback for connection lost."""
Expand Down
60 changes: 53 additions & 7 deletions xknx/io/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from abc import abstractmethod
import asyncio
from collections.abc import Awaitable, Callable
import logging
from typing import TYPE_CHECKING

from xknx.cemi import CEMIFrame
from xknx.core import XknxConnectionState, XknxConnectionType
from xknx.exceptions import CommunicationError, TunnellingAckError
from xknx.exceptions import CommunicationError, TunnellingAckError, XKNXException
from xknx.knxip import (
HPAI,
ConnectRequestInformation,
Expand Down Expand Up @@ -552,15 +553,38 @@ def __init__(
self,
xknx: XKNX,
cemi_received_callback: CEMIBytesCallbackType,
gateway_ip: str,
gateway_ip: str | None,
gateway_port: int,
gateway_path: str | None = None,
connect_cb: Callable[
[asyncio.AbstractEventLoop, Callable[[], asyncio.Protocol]],
Awaitable[tuple[asyncio.Transport, asyncio.Protocol]],
]
| None = None,
individual_address: IndividualAddress | None = None,
auto_reconnect: bool = True,
auto_reconnect_wait: int = 3,
):
"""Initialize Tunnel class."""

arg_count = (
(gateway_ip is not None)
+ (gateway_path is not None)
+ (connect_cb is not None)
)
if arg_count > 1:
raise XKNXException(
"Only one of gateway_ip, gateway_path and connect_cb may be set"
)
if arg_count == 0:
raise XKNXException(
"One of gateway_ip, gateway_path and connect_cb must be set"
)

self.gateway_ip = gateway_ip
self.gateway_port = gateway_port
self.gateway_path = gateway_path
self.connect_cb = connect_cb
super().__init__(
xknx=xknx,
cemi_received_callback=cemi_received_callback,
Expand All @@ -573,10 +597,29 @@ def __init__(

def _init_transport(self) -> None:
"""Initialize transport transport."""
self.transport = TCPTransport(
remote_addr=(self.gateway_ip, self.gateway_port),
connection_lost_cb=self._tunnel_lost,
)
if self.connect_cb is not None:
self.transport = TCPTransport(
remote_addr=("0.0.0.0", 0),
connection_lost_cb=self._tunnel_lost,
connect_cb=self.connect_cb,
)
elif self.gateway_path is not None:
self.transport = TCPTransport(
remote_addr=("0.0.0.0", 0),
connection_lost_cb=self._tunnel_lost,
connect_cb=lambda loop, protocol_factory: loop.create_unix_connection(
protocol_factory, path=self.gateway_path
),
)
elif self.gateway_ip is not None:
self.transport = TCPTransport(
remote_addr=(self.gateway_ip, self.gateway_port),
connection_lost_cb=self._tunnel_lost,
)
else:
raise XKNXException(
"One of gateway_ip, gateway_path and connect_cb must be set"
)

async def setup_tunnel(self) -> None:
"""Set up tunnel before sending a ConnectionRequest."""
Expand Down Expand Up @@ -619,8 +662,11 @@ def __init__(

def _init_transport(self) -> None:
"""Initialize transport transport."""
ip = self.gateway_ip
if ip is None:
ip = "unknown"
self.transport = SecureSession(
remote_addr=(self.gateway_ip, self.gateway_port),
remote_addr=(ip, self.gateway_port),
user_id=self._user_id,
user_password=self._user_password,
device_authentication_password=self._device_authentication_password,
Expand Down

0 comments on commit 5dba908

Please sign in to comment.