diff --git a/bumble/controller.py b/bumble/controller.py index eb202927..eff274a4 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -374,6 +374,12 @@ def find_classic_connection_by_handle(self, handle): return connection return None + def find_peripheral_connection_by_handle(self, handle): + for connection in self.peripheral_connections.values(): + if connection.handle == handle: + return connection + return None + def on_link_central_connected(self, central_address): ''' Called when an incoming connection occurs from a central on the link @@ -877,6 +883,14 @@ def on_hci_disconnect_command(self, command): else: # Remove the connection del self.central_connections[connection.peer_address] + elif connection := self.find_peripheral_connection_by_handle(handle): + if self.link: + self.link.disconnect( + connection.peer_address, self.random_address, command + ) + else: + # Remove the connection + del self.peripheral_connections[connection.peer_address] elif connection := self.find_classic_connection_by_handle(handle): if self.link: self.link.classic_disconnect( diff --git a/tests/connect_test.py b/tests/connect_test.py new file mode 100644 index 00000000..2f0da011 --- /dev/null +++ b/tests/connect_test.py @@ -0,0 +1,95 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import pytest +from unittest import mock + +from bumble.controller import Controller +from bumble.device import Connection, Device +from bumble.hci import HCI_CONNECTION_TERMINATED_BY_LOCAL_HOST_ERROR +from bumble.host import Host +from bumble.link import LocalLink + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def link() -> LocalLink: + return LocalLink() + + +@pytest.fixture +async def central_device(link) -> Device: + controller = Controller('Central', link=link) + host = Host() + host.controller = controller + device = Device(host=host) + await device.power_on() + return device + + +@pytest.fixture +async def peripheral_device(link) -> Device: + controller = Controller('Peripheral', link=link) + host = Host() + host.controller = controller + device = Device(host=host) + await device.power_on() + return device + + +async def connect(central_device, peripheral_device) -> Connection: + return await central_device.connect( + peripheral_device.host.controller.random_address + ) + + +@pytest.mark.asyncio +async def test_connect(central_device, peripheral_device): + conn = await connect(central_device, peripheral_device) + assert conn.self_address == central_device.host.controller.random_address + assert conn.peer_address == peripheral_device.host.controller.random_address + + +@pytest.fixture +async def connection(central_device, peripheral_device) -> Connection: + return await connect(central_device, peripheral_device) + + +@pytest.mark.asyncio +async def test_disconnect_from_central(central_device, peripheral_device, connection): + assert peripheral_device.connections + await asyncio.wait_for( + central_device.disconnect( + connection, reason=HCI_CONNECTION_TERMINATED_BY_LOCAL_HOST_ERROR + ), + timeout=1.0, + ) + assert not peripheral_device.connections + + +@pytest.mark.asyncio +async def test_disconnect_from_peripheral( + central_device, peripheral_device, connection +): + assert central_device.connections + await asyncio.wait_for( + peripheral_device.disconnect( + connection, reason=HCI_CONNECTION_TERMINATED_BY_LOCAL_HOST_ERROR + ), + timeout=1.0, + ) + assert not central_device.connections