Skip to content

Commit

Permalink
Fix config error when ports are passed as strings (#865)
Browse files Browse the repository at this point in the history
  • Loading branch information
bachya authored Jan 23, 2024
1 parent 5df7392 commit 04aaa50
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
30 changes: 26 additions & 4 deletions ecowitt2mqtt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from collections.abc import Generator
from numbers import Number
from typing import Annotated, Any
from typing import Any
from uuid import uuid4

from pydantic import (
Expand Down Expand Up @@ -92,6 +92,24 @@ def validate_boolean(value: bool | str | Number) -> bool:
raise ValueError(f"invalid boolean value: {value}")


def validate_port(value: int | str) -> int:
"""Validate a port
Args:
value: A value to validate.
Returns:
An validated port.
Raises:
ValueError: Raises if the value is not a valid port.
"""
parsed = int(value)
if not 1 <= parsed <= 65536:
raise ValueError(f"invalid port: {value}")
return parsed


class Config(BaseModel):
"""Define a config object."""

Expand All @@ -102,7 +120,7 @@ class Config(BaseModel):

# Optional MQTT parameters:
mqtt_password: str | None = None
mqtt_port: Annotated[int, Field(strict=True, ge=1, le=65536)] = DEFAULT_MQTT_PORT
mqtt_port: int = DEFAULT_MQTT_PORT
mqtt_retain: bool = False
mqtt_tls: bool = False
mqtt_topic: str | None = None
Expand All @@ -126,7 +144,7 @@ class Config(BaseModel):

# Optional HTTP parameters:
endpoint: str = DEFAULT_ENDPOINT
port: Annotated[int, Field(strict=True, ge=1, le=65536)] = DEFAULT_PORT
port: int = DEFAULT_PORT

# Optional logging parameters:
diagnostics: bool = False
Expand Down Expand Up @@ -240,7 +258,7 @@ def validate_boolean_battery_true_value(cls, value: int | str) -> int:
"disable_calculated_data", mode="before"
)(validate_boolean)

validate_has_discovery = field_validator("hass_discovery", mode="before")(
validate_hass_discovery = field_validator("hass_discovery", mode="before")(
validate_boolean
)

Expand All @@ -264,12 +282,16 @@ def validate_mqtt_auth(cls, data: dict[str, Any]) -> dict[str, Any]:
raise ValueError("Invalid MQTT auth configuration")
return data

validate_mqtt_port = field_validator("mqtt_port", mode="before")(validate_port)

validate_mqtt_retain = field_validator("mqtt_retain", mode="before")(
validate_boolean
)

validate_mqtt_tls = field_validator("mqtt_tls", mode="before")(validate_boolean)

validate_port = field_validator("port", mode="before")(validate_port)

validate_raw_data = field_validator("raw_data", mode="before")(validate_boolean)

@model_validator(mode="before")
Expand Down
21 changes: 21 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,27 @@ def test_precision() -> None:
assert configs.default_config.precision == 2


@pytest.mark.parametrize(
"config,is_valid",
(
[TEST_CONFIG_JSON | {CONF_PORT: 1883}, True],
[TEST_CONFIG_JSON | {CONF_PORT: 9000}, True],
[TEST_CONFIG_JSON | {CONF_PORT: "1883"}, True],
[TEST_CONFIG_JSON | {CONF_PORT: "Not a port"}, False],
),
)
def test_port(config: dict[str, Any], is_valid: bool) -> None:
"""Test validating a port.
Args:
config: A configuration dictionary.
is_valid: Whether the configuration is valid.
"""
if not is_valid:
with pytest.raises(ConfigError):
_ = Configs(config)


@pytest.mark.parametrize(
"config,verbose_value",
[
Expand Down

0 comments on commit 04aaa50

Please sign in to comment.