From 04aaa50424ae4b99ec53fa237bdb3f69d262064a Mon Sep 17 00:00:00 2001 From: Aaron Bach Date: Mon, 22 Jan 2024 18:16:03 -0700 Subject: [PATCH] Fix config error when ports are passed as strings (#865) --- ecowitt2mqtt/config.py | 30 ++++++++++++++++++++++++++---- tests/test_config.py | 21 +++++++++++++++++++++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/ecowitt2mqtt/config.py b/ecowitt2mqtt/config.py index d38dd7d8..af2ce185 100644 --- a/ecowitt2mqtt/config.py +++ b/ecowitt2mqtt/config.py @@ -4,7 +4,7 @@ import os from collections.abc import Generator from numbers import Number -from typing import Annotated, Any +from typing import Any from uuid import uuid4 from pydantic import ( @@ -92,6 +92,24 @@ def validate_boolean(value: bool | str | Number) -> bool: raise ValueError(f"invalid boolean value: {value}") +def validate_port(value: int | str) -> int: + """Validate a port + + Args: + value: A value to validate. + + Returns: + An validated port. + + Raises: + ValueError: Raises if the value is not a valid port. + """ + parsed = int(value) + if not 1 <= parsed <= 65536: + raise ValueError(f"invalid port: {value}") + return parsed + + class Config(BaseModel): """Define a config object.""" @@ -102,7 +120,7 @@ class Config(BaseModel): # Optional MQTT parameters: mqtt_password: str | None = None - mqtt_port: Annotated[int, Field(strict=True, ge=1, le=65536)] = DEFAULT_MQTT_PORT + mqtt_port: int = DEFAULT_MQTT_PORT mqtt_retain: bool = False mqtt_tls: bool = False mqtt_topic: str | None = None @@ -126,7 +144,7 @@ class Config(BaseModel): # Optional HTTP parameters: endpoint: str = DEFAULT_ENDPOINT - port: Annotated[int, Field(strict=True, ge=1, le=65536)] = DEFAULT_PORT + port: int = DEFAULT_PORT # Optional logging parameters: diagnostics: bool = False @@ -240,7 +258,7 @@ def validate_boolean_battery_true_value(cls, value: int | str) -> int: "disable_calculated_data", mode="before" )(validate_boolean) - validate_has_discovery = field_validator("hass_discovery", mode="before")( + validate_hass_discovery = field_validator("hass_discovery", mode="before")( validate_boolean ) @@ -264,12 +282,16 @@ def validate_mqtt_auth(cls, data: dict[str, Any]) -> dict[str, Any]: raise ValueError("Invalid MQTT auth configuration") return data + validate_mqtt_port = field_validator("mqtt_port", mode="before")(validate_port) + validate_mqtt_retain = field_validator("mqtt_retain", mode="before")( validate_boolean ) validate_mqtt_tls = field_validator("mqtt_tls", mode="before")(validate_boolean) + validate_port = field_validator("port", mode="before")(validate_port) + validate_raw_data = field_validator("raw_data", mode="before")(validate_boolean) @model_validator(mode="before") diff --git a/tests/test_config.py b/tests/test_config.py index 7c66a903..134d7c7c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -470,6 +470,27 @@ def test_precision() -> None: assert configs.default_config.precision == 2 +@pytest.mark.parametrize( + "config,is_valid", + ( + [TEST_CONFIG_JSON | {CONF_PORT: 1883}, True], + [TEST_CONFIG_JSON | {CONF_PORT: 9000}, True], + [TEST_CONFIG_JSON | {CONF_PORT: "1883"}, True], + [TEST_CONFIG_JSON | {CONF_PORT: "Not a port"}, False], + ), +) +def test_port(config: dict[str, Any], is_valid: bool) -> None: + """Test validating a port. + + Args: + config: A configuration dictionary. + is_valid: Whether the configuration is valid. + """ + if not is_valid: + with pytest.raises(ConfigError): + _ = Configs(config) + + @pytest.mark.parametrize( "config,verbose_value", [