From 85b794d9dc96c02b5d03123240352cd32dcbd129 Mon Sep 17 00:00:00 2001 From: drc38 <20024196+drc38@users.noreply.github.com> Date: Thu, 26 Dec 2024 01:18:39 +1300 Subject: [PATCH] switch to new websockets asyncio (#1439) * switch to new websockets asyncio * add errors raised * fix error raise * extra detail on nosub test * add test for autoconfig * remove old tests * update deprecated enums in v16 tests * extra test for reboot required * update pre-commit * fix key --------- Co-authored-by: lbbrhzn <8673442+lbbrhzn@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- custom_components/ocpp/api.py | 56 +++++---- custom_components/ocpp/chargepoint.py | 12 +- custom_components/ocpp/manifest.json | 2 +- custom_components/ocpp/ocppv16.py | 4 +- custom_components/ocpp/ocppv201.py | 4 +- tests/charge_point_test.py | 7 +- tests/conftest.py | 8 +- tests/const.py | 1 + tests/test_charge_point_v16.py | 173 ++++++++++---------------- 10 files changed, 124 insertions(+), 145 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e658f46a..261aa8f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: # Run the formatter. - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v5.0.0 hooks: - id: check-executables-have-shebangs stages: [manual] diff --git a/custom_components/ocpp/api.py b/custom_components/ocpp/api.py index 9d23c8b9..f1f058cc 100644 --- a/custom_components/ocpp/api.py +++ b/custom_components/ocpp/api.py @@ -8,9 +8,9 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import STATE_OK from homeassistant.core import HomeAssistant -from websockets import Subprotocol -import websockets.protocol +from websockets import Subprotocol, NegotiationError import websockets.server +from websockets.asyncio.server import ServerConnection from .chargepoint import CentralSystemSettings from .ocppv16 import ChargePoint as ChargePointv16 @@ -21,7 +21,6 @@ CONF_CSID, CONF_HOST, CONF_PORT, - CONF_SKIP_SCHEMA_VALIDATION, CONF_SSL, CONF_SSL_CERTFILE_PATH, CONF_SSL_KEYFILE_PATH, @@ -34,7 +33,6 @@ DEFAULT_CSID, DEFAULT_HOST, DEFAULT_PORT, - DEFAULT_SKIP_SCHEMA_VALIDATION, DEFAULT_SSL, DEFAULT_SSL_CERTFILE_PATH, DEFAULT_SSL_KEYFILE_PATH, @@ -112,10 +110,11 @@ async def create(hass: HomeAssistant, entry: ConfigEntry): """Create instance and start listening for OCPP connections on given port.""" self = CentralSystem(hass, entry) - server = await websockets.server.serve( + server = await websockets.serve( self.on_connect, self.host, self.port, + select_subprotocol=self.select_subprotocol, subprotocols=self.subprotocols, ping_interval=None, # ping interval is not used here, because we send pings mamually in ChargePoint.monitor_connection() ping_timeout=None, @@ -125,27 +124,38 @@ async def create(hass: HomeAssistant, entry: ConfigEntry): self._server = server return self - async def on_connect(self, websocket: websockets.server.WebSocketServerProtocol): + def select_subprotocol( + self, connection: ServerConnection, subprotocols + ) -> Subprotocol | None: + """Override default subprotocol selection.""" + + # Server offers at least one subprotocol but client doesn't offer any. + # Default to None + if not subprotocols: + return None + + # Server and client both offer subprotocols. Look for a shared one. + proposed_subprotocols = set(subprotocols) + for subprotocol in proposed_subprotocols: + if subprotocol in self.subprotocols: + return subprotocol + + # No common subprotocol was found. + raise NegotiationError( + "invalid subprotocol; expected one of " + ", ".join(self.subprotocols) + ) + + async def on_connect(self, websocket: ServerConnection): """Request handler executed for every new OCPP connection.""" - if self.config.get(CONF_SKIP_SCHEMA_VALIDATION, DEFAULT_SKIP_SCHEMA_VALIDATION): - _LOGGER.warning("Skipping websocket subprotocol validation") + if websocket.subprotocol is not None: + _LOGGER.info("Websocket Subprotocol matched: %s", websocket.subprotocol) else: - if websocket.subprotocol is not None: - _LOGGER.info("Websocket Subprotocol matched: %s", websocket.subprotocol) - else: - # In the websockets lib if no subprotocols are supported by the - # client and the server, it proceeds without a subprotocol, - # so we have to manually close the connection. - _LOGGER.warning( - "Protocols mismatched | expected Subprotocols: %s," - " but client supports %s | Closing connection", - websocket.available_subprotocols, - websocket.request_headers.get("Sec-WebSocket-Protocol", ""), - ) - return await websocket.close() + _LOGGER.info( + "Websocket Subprotocol not provided by charger: default to ocpp1.6" + ) - _LOGGER.info(f"Charger websocket path={websocket.path}") - cp_id = websocket.path.strip("/") + _LOGGER.info(f"Charger websocket path={websocket.request.path}") + cp_id = websocket.request.path.strip("/") cp_id = cp_id[cp_id.rfind("/") + 1 :] if self.settings.cpid not in self.charge_points: _LOGGER.info(f"Charger {cp_id} connected to {self.host}:{self.port}.") diff --git a/custom_components/ocpp/chargepoint.py b/custom_components/ocpp/chargepoint.py index 2655fd11..7fd22e3f 100644 --- a/custom_components/ocpp/chargepoint.py +++ b/custom_components/ocpp/chargepoint.py @@ -21,7 +21,9 @@ from homeassistant.helpers import device_registry, entity_component, entity_registry import homeassistant.helpers.config_validation as cv import voluptuous as vol -import websockets.server +from websockets.asyncio.server import ServerConnection +from websockets.exceptions import WebSocketException +from websockets.protocol import State from ocpp.charge_point import ChargePoint as cp from ocpp.v16 import call as callv16 @@ -471,7 +473,7 @@ async def monitor_connection(self): self._metrics[cstat.latency_pong.value].unit = "ms" connection = self._connection timeout_counter = 0 - while connection.open: + while connection.state is State.OPEN: try: await asyncio.sleep(self.central.websocket_ping_interval) time0 = time.perf_counter() @@ -529,7 +531,7 @@ async def run(self, tasks): await asyncio.gather(*self.tasks) except TimeoutError: pass - except websockets.exceptions.WebSocketException as websocket_exception: + except WebSocketException as websocket_exception: _LOGGER.debug(f"Connection closed to '{self.id}': {websocket_exception}") except Exception as other_exception: _LOGGER.error( @@ -542,13 +544,13 @@ async def run(self, tasks): async def stop(self): """Close connection and cancel ongoing tasks.""" self.status = STATE_UNAVAILABLE - if self._connection.open: + if self._connection.state is State.OPEN: _LOGGER.debug(f"Closing websocket to '{self.id}'") await self._connection.close() for task in self.tasks: task.cancel() - async def reconnect(self, connection: websockets.server.WebSocketServerProtocol): + async def reconnect(self, connection: ServerConnection): """Reconnect charge point.""" _LOGGER.debug(f"Reconnect websocket to {self.id}") diff --git a/custom_components/ocpp/manifest.json b/custom_components/ocpp/manifest.json index 9f826d71..f0f1ef50 100644 --- a/custom_components/ocpp/manifest.json +++ b/custom_components/ocpp/manifest.json @@ -14,7 +14,7 @@ "issue_tracker": "https://github.com/lbbrhzn/ocpp/issues", "requirements": [ "ocpp>=1.0.0", - "websockets>=12.0" + "websockets>=13.1" ], "version": "0.6.1" } diff --git a/custom_components/ocpp/ocppv16.py b/custom_components/ocpp/ocppv16.py index 4efc7569..b49ecf32 100644 --- a/custom_components/ocpp/ocppv16.py +++ b/custom_components/ocpp/ocppv16.py @@ -9,7 +9,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant import voluptuous as vol -import websockets.server +from websockets.asyncio.server import ServerConnection from ocpp.routing import on from ocpp.v16 import call, call_result @@ -74,7 +74,7 @@ class ChargePoint(cp): def __init__( self, id: str, - connection: websockets.server.WebSocketServerProtocol, + connection: ServerConnection, hass: HomeAssistant, entry: ConfigEntry, central: CentralSystemSettings, diff --git a/custom_components/ocpp/ocppv201.py b/custom_components/ocpp/ocppv201.py index f1ea7bd3..3114984f 100644 --- a/custom_components/ocpp/ocppv201.py +++ b/custom_components/ocpp/ocppv201.py @@ -11,7 +11,7 @@ from homeassistant.const import UnitOfTime from homeassistant.core import HomeAssistant, SupportsResponse, ServiceResponse from homeassistant.exceptions import ServiceValidationError, HomeAssistantError -import websockets.server +from websockets.asyncio.server import ServerConnection from ocpp.routing import on from ocpp.v201 import call, call_result @@ -85,7 +85,7 @@ class ChargePoint(cp): def __init__( self, id: str, - connection: websockets.server.WebSocketServerProtocol, + connection: ServerConnection, hass: HomeAssistant, entry: ConfigEntry, central: CentralSystemSettings, diff --git a/tests/charge_point_test.py b/tests/charge_point_test.py index 02bea1ef..0398c276 100644 --- a/tests/charge_point_test.py +++ b/tests/charge_point_test.py @@ -23,7 +23,8 @@ from pytest_homeassistant_custom_component.common import MockConfigEntry from typing import Any from collections.abc import Callable, Awaitable -import websockets +from websockets import connect +from websockets.asyncio.client import ClientConnection async def set_switch(hass: HomeAssistant, cs: CentralSystem, key: str, on: bool): @@ -108,12 +109,12 @@ async def run_charge_point_test( config_entry: MockConfigEntry, identity: str, subprotocols: list[str] | None, - charge_point: Callable[[websockets.WebSocketClientProtocol], ChargePoint], + charge_point: Callable[[ClientConnection], ChargePoint], parallel_tests: list[Callable[[ChargePoint], Awaitable]], ) -> Any: """Connect web socket client to the CSMS and run a number of tests in parallel.""" completed: list[list[bool]] = [[] for _ in parallel_tests] - async with websockets.connect( + async with connect( f"ws://127.0.0.1:{config_entry.data[CONF_PORT]}/{identity}", subprotocols=[Subprotocol(s) for s in subprotocols] if subprotocols is not None diff --git a/tests/conftest.py b/tests/conftest.py index 4264d5c2..9ce3d091 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,11 +35,11 @@ def skip_notifications_fixture(): def bypass_get_data_fixture(): """Skip calls to get data from API.""" future = asyncio.Future() - future.set_result(websockets.WebSocketServer) + future.set_result(websockets.asyncio.server.Server) with ( - patch("websockets.server.serve", return_value=future), - patch("websockets.server.WebSocketServer.close"), - patch("websockets.server.WebSocketServer.wait_closed"), + patch("websockets.asyncio.server.serve", return_value=future), + patch("websockets.asyncio.server.Server.close"), + patch("websockets.asyncio.server.Server.wait_closed"), ): yield diff --git a/tests/const.py b/tests/const.py index e66520be..69897707 100644 --- a/tests/const.py +++ b/tests/const.py @@ -76,6 +76,7 @@ CONF_PORT: 9002, CONF_CPID: "test_cpid_2", CONF_SKIP_SCHEMA_VALIDATION: True, + CONF_MONITORED_VARIABLES_AUTOCONFIG: False, } # separate entry for switch so tests can run concurrently diff --git a/tests/test_charge_point_v16.py b/tests/test_charge_point_v16.py index ffc60a23..3e63c39e 100644 --- a/tests/test_charge_point_v16.py +++ b/tests/test_charge_point_v16.py @@ -117,96 +117,68 @@ async def test_services(hass, cs, socket_enabled): await set_number(hass, cs, number.key, 10) # Test MOCK_CONFIG_DATA_2 - if True: - # Create a mock entry so we don't have to go through config flow - config_entry2 = MockConfigEntry( - domain=OCPP_DOMAIN, - data=MOCK_CONFIG_DATA_2, - entry_id="test_cms2", - title="test_cms2", - ) - config_entry2.add_to_hass(hass) - assert await hass.config_entries.async_setup(config_entry2.entry_id) - await hass.async_block_till_done() - - # no subprotocol - # NB each new config entry will trigger async_update_entry - # if the charger measurands differ from the config entry - # which causes the websocket server to close/restart with a - # ConnectionClosedOK exception, hence it needs to be passed/suppressed - async with websockets.connect( - "ws://127.0.0.1:9002/CP_1_nosub", - ) as ws2: - # use a different id for debugging - cp2 = ChargePoint("CP_1_no_subprotocol", ws2) - with contextlib.suppress( - asyncio.TimeoutError, websockets.exceptions.ConnectionClosedOK - ): - await asyncio.wait_for( - asyncio.gather( - cp2.start(), - cp2.send_boot_notification(), - cp2.send_authorize(), - cp2.send_heartbeat(), - cp2.send_status_notification(), - cp2.send_firmware_status(), - cp2.send_data_transfer(), - cp2.send_start_transaction(), - cp2.send_stop_transaction(), - cp2.send_meter_periodic_data(), - ), - timeout=5, - ) - await ws2.close() - await asyncio.sleep(1) - if entry := hass.config_entries.async_get_entry(config_entry2.entry_id): - await hass.config_entries.async_remove(entry.entry_id) - await hass.async_block_till_done() - # Create a mock entry so we don't have to go through config flow - config_entry = MockConfigEntry( - domain=OCPP_DOMAIN, data=MOCK_CONFIG_DATA, entry_id="test_cms", title="test_cms" + config_entry2 = MockConfigEntry( + domain=OCPP_DOMAIN, + data=MOCK_CONFIG_DATA_2, + entry_id="test_cms2", + title="test_cms2", ) - config_entry.add_to_hass(hass) - assert await hass.config_entries.async_setup(config_entry.entry_id) + config_entry2.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry2.entry_id) await hass.async_block_till_done() - cs = hass.data[OCPP_DOMAIN][config_entry.entry_id] - - # no subprotocol + # no subprotocol central system assumes ocpp1.6 charge point + # NB each new config entry will trigger async_update_entry + # if the charger measurands differ from the config entry + # which causes the websocket server to close/restart with a + # ConnectionClosedOK exception, hence it needs to be passed/suppressed async with websockets.connect( - "ws://127.0.0.1:9000/CP_1_unsup", - ) as ws: + "ws://127.0.0.1:9002/CP_1_nosub", + ) as ws2: # use a different id for debugging - cp = ChargePoint("CP_1_no_subprotocol", ws) - with contextlib.suppress(websockets.exceptions.ConnectionClosedOK): + assert ws2.subprotocol is None + cp2 = ChargePoint("CP_1_no_subprotocol", ws2) + with contextlib.suppress( + asyncio.TimeoutError, websockets.exceptions.ConnectionClosedOK + ): await asyncio.wait_for( asyncio.gather( - cp.start(), + cp2.start(), + cp2.send_boot_notification(), + cp2.send_authorize(), + cp2.send_heartbeat(), + cp2.send_status_notification(), + cp2.send_firmware_status(), + cp2.send_data_transfer(), + cp2.send_start_transaction(), + cp2.send_stop_transaction(), + cp2.send_meter_periodic_data(), ), - timeout=3, + timeout=5, ) - await ws.close() - + await ws2.close() await asyncio.sleep(1) + if entry := hass.config_entries.async_get_entry(config_entry2.entry_id): + await hass.config_entries.async_remove(entry.entry_id) + await hass.async_block_till_done() - # unsupported subprotocol - async with websockets.connect( - "ws://127.0.0.1:9000/CP_1_unsup", - subprotocols=["ocpp0.0"], - ) as ws: - # use a different id for debugging - cp = ChargePoint("CP_1_unsupported_subprotocol", ws) - with contextlib.suppress(websockets.exceptions.ConnectionClosedOK): - await asyncio.wait_for( - asyncio.gather( - cp.start(), - ), - timeout=3, - ) - await ws.close() + # Create a mock entry so we don't have to go through config flow + config_entry = MockConfigEntry( + domain=OCPP_DOMAIN, data=MOCK_CONFIG_DATA, entry_id="test_cms", title="test_cms" + ) + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() - await asyncio.sleep(1) + cs = hass.data[OCPP_DOMAIN][config_entry.entry_id] + + # unsupported subprotocol raises websockets exception + with pytest.raises(websockets.exceptions.InvalidStatus): + await websockets.connect( + "ws://127.0.0.1:9000/CP_1_unsup", + subprotocols=["ocpp0.0"], + ) # test restore feature of meter_start and active_tranasction_id. async with websockets.connect( @@ -431,18 +403,6 @@ async def test_services(hass, cs, socket_enabled): await asyncio.sleep(1) - # setting state no longer available with websockets >14 - # test ping timeout, change cpid to start new connection - # cs.settings.cpid = "CP_3_test" - # async with websockets.connect( - # "ws://127.0.0.1:9000/CP_3", - # subprotocols=["ocpp1.6"], - # ) as ws: - # cp = ChargePoint("CP_3_test", ws) - # ws.state = 3 # CLOSED = 3 - # await asyncio.sleep(3) - # await ws.close() - # test services when charger is unavailable await asyncio.sleep(1) await test_services(hass, cs, socket_enabled) @@ -460,7 +420,7 @@ def __init__(self, id, connection, response_timeout=30): self.active_transactionId: int = 0 self.accept: bool = True - @on(Action.GetConfiguration) + @on(Action.get_configuration) def on_get_configuration(self, key, **kwargs): """Handle a get configuration requests.""" if key[0] == ConfigurationKey.supported_feature_profiles.value: @@ -547,15 +507,20 @@ def on_get_configuration(self, key, **kwargs): configuration_key=[{"key": key[0], "readonly": False, "value": ""}] ) - @on(Action.ChangeConfiguration) - def on_change_configuration(self, **kwargs): + @on(Action.change_configuration) + def on_change_configuration(self, key, **kwargs): """Handle a get configuration request.""" if self.accept is True: - return call_result.ChangeConfiguration(ConfigurationStatus.accepted) + if key == ConfigurationKey.meter_values_sampled_data.value: + return call_result.ChangeConfiguration( + ConfigurationStatus.reboot_required + ) + else: + return call_result.ChangeConfiguration(ConfigurationStatus.accepted) else: return call_result.ChangeConfiguration(ConfigurationStatus.rejected) - @on(Action.ChangeAvailability) + @on(Action.change_availability) def on_change_availability(self, **kwargs): """Handle change availability request.""" if self.accept is True: @@ -563,7 +528,7 @@ def on_change_availability(self, **kwargs): else: return call_result.ChangeAvailability(AvailabilityStatus.rejected) - @on(Action.UnlockConnector) + @on(Action.unlock_connector) def on_unlock_connector(self, **kwargs): """Handle unlock request.""" if self.accept is True: @@ -571,7 +536,7 @@ def on_unlock_connector(self, **kwargs): else: return call_result.UnlockConnector(UnlockStatus.unlock_failed) - @on(Action.Reset) + @on(Action.reset) def on_reset(self, **kwargs): """Handle change availability request.""" if self.accept is True: @@ -579,7 +544,7 @@ def on_reset(self, **kwargs): else: return call_result.Reset(ResetStatus.rejected) - @on(Action.RemoteStartTransaction) + @on(Action.remote_start_transaction) def on_remote_start_transaction(self, **kwargs): """Handle remote start request.""" if self.accept is True: @@ -588,7 +553,7 @@ def on_remote_start_transaction(self, **kwargs): else: return call_result.RemoteStopTransaction(RemoteStartStopStatus.rejected) - @on(Action.RemoteStopTransaction) + @on(Action.remote_stop_transaction) def on_remote_stop_transaction(self, **kwargs): """Handle remote stop request.""" if self.accept is True: @@ -596,7 +561,7 @@ def on_remote_stop_transaction(self, **kwargs): else: return call_result.RemoteStopTransaction(RemoteStartStopStatus.rejected) - @on(Action.SetChargingProfile) + @on(Action.set_charging_profile) def on_set_charging_profile(self, **kwargs): """Handle set charging profile request.""" if self.accept is True: @@ -604,7 +569,7 @@ def on_set_charging_profile(self, **kwargs): else: return call_result.SetChargingProfile(ChargingProfileStatus.rejected) - @on(Action.ClearChargingProfile) + @on(Action.clear_charging_profile) def on_clear_charging_profile(self, **kwargs): """Handle clear charging profile request.""" if self.accept is True: @@ -612,7 +577,7 @@ def on_clear_charging_profile(self, **kwargs): else: return call_result.ClearChargingProfile(ClearChargingProfileStatus.unknown) - @on(Action.TriggerMessage) + @on(Action.trigger_message) def on_trigger_message(self, **kwargs): """Handle trigger message request.""" if self.accept is True: @@ -620,17 +585,17 @@ def on_trigger_message(self, **kwargs): else: return call_result.TriggerMessage(TriggerMessageStatus.rejected) - @on(Action.UpdateFirmware) + @on(Action.update_firmware) def on_update_firmware(self, **kwargs): """Handle update firmware request.""" return call_result.UpdateFirmware() - @on(Action.GetDiagnostics) + @on(Action.get_diagnostics) def on_get_diagnostics(self, **kwargs): """Handle get diagnostics request.""" return call_result.GetDiagnostics() - @on(Action.DataTransfer) + @on(Action.data_transfer) def on_data_transfer(self, **kwargs): """Handle get data transfer request.""" if self.accept is True: