Skip to content

Commit

Permalink
Merge branch 'dev' into docker
Browse files Browse the repository at this point in the history
  • Loading branch information
bachya authored Oct 29, 2023
2 parents 389cd46 + 073d6a7 commit 5b5d145
Show file tree
Hide file tree
Showing 2 changed files with 562 additions and 532 deletions.
208 changes: 121 additions & 87 deletions ecowitt2mqtt/helpers/publisher/mqtt/hass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any, TypedDict

from aiomqtt import Client, MqttError
Expand Down Expand Up @@ -152,12 +152,43 @@ class EntityDescription:
state_class: str | None = None


@dataclass(frozen=True)
class HassDiscoveryDevice:
"""Define an MQTT Discovery device."""

identifiers: list[str]
manufacturer: str
model: str
name: str
sw_version: str


@dataclass
class HassDiscoveryPayload:
"""Define a MQTT Discovery configuration for an entity."""
class HassDiscoveryInfo:
"""Define an MQTT Discovery payload."""

payload: dict[str, Any]
topic: str
availability_topic: str
config_topic: str
device: HassDiscoveryDevice
json_attributes_topic: str
name: str
retain: bool
state_topic: str
unique_id: str

device_class: str | None = None
entity_category: str | None = None
icon: str | None = None
object_id: str | None = None
qos: int = 1
state_class: str | None = None
unit_of_measurement: str | None = None

@staticmethod
def dict_factory(x: list[tuple[str, Any]]) -> dict[str, Any]:
"""Define a default dict factory for the dataclass."""
exclude_fields = ("config_topic",)
return {k: v for (k, v) in x if ((v is not None) and (k not in exclude_fields))}


AVAILABILITY_OFFLINE = "offline"
Expand Down Expand Up @@ -443,23 +474,12 @@ def __init__(self, config: Config, client: Client) -> None:
"""
super().__init__(config, client)

self._discovery_payloads: dict[str, HassDiscoveryPayload] = {}

def _generate_discovery_payload( # pylint: disable=too-many-branches
self, device: Device, payload_key: str, data_point: CalculatedDataPoint
) -> HassDiscoveryPayload:
"""Generate a discovery payload for an entity.
Args:
device: A Device object.
payload_key: The Ecowitt payload key.
data_point: A CalculatedDataPoint object.
self._discovery_infos: dict[str, HassDiscoveryInfo] = {}

Returns:
A parsed HassDiscoveryPayload object.
"""
# Since batteries can be one of many different strategies, we calculate an
# entity description at runtime:
def _get_data_point_key(
self, payload_key: str, data_point: CalculatedDataPoint
) -> str:
"""Get the data point key."""
if data_point.data_point_key in (DATA_POINT_GLOB_BATT, DATA_POINT_GLOB_VOLT):
strategy = get_battery_strategy(self._config, payload_key)
if strategy == BatteryStrategy.BOOLEAN:
Expand All @@ -471,60 +491,58 @@ def _generate_discovery_payload( # pylint: disable=too-many-branches
else:
data_point_key = data_point.data_point_key

return data_point_key

def _get_discovery_info(
self, device: Device, payload_key: str, data_point: CalculatedDataPoint
) -> HassDiscoveryInfo:
"""Get the discovery payload from a payload."""
base_topic = (
f"{self._config.hass_discovery_prefix}/{PLATFORM_MAP[data_point.data_type]}"
f"/{device.unique_id}/{payload_key}"
)

config = {
"availability_topic": f"{base_topic}/availability",
"device": {
"identifiers": [device.unique_id],
"manufacturer": device.manufacturer,
"model": device.model,
"name": device.name,
"sw_version": device.station_type,
},
"json_attributes_topic": f"{base_topic}/attributes",
"name": payload_key,
"qos": 1,
"retain": self._config.mqtt_retain,
"state_topic": f"{base_topic}/state",
"unique_id": f"{device.unique_id}_{payload_key}",
}
if self._config.hass_entity_id_prefix:
config["object_id"] = f"{self._config.hass_entity_id_prefix}_{payload_key}"

payload = self._discovery_payloads[payload_key] = HassDiscoveryPayload(
config, f"{base_topic}/config"
discovery = HassDiscoveryInfo(
availability_topic=f"{base_topic}/availability",
config_topic=f"{base_topic}/config",
device=HassDiscoveryDevice(
identifiers=[device.unique_id],
manufacturer=device.manufacturer,
model=device.model,
name=device.name,
sw_version=device.station_type,
),
json_attributes_topic=f"{base_topic}/attributes",
name=payload_key,
retain=self._config.mqtt_retain,
state_topic=f"{base_topic}/state",
unique_id=f"{device.unique_id}_{payload_key}",
)

if data_point.attributes:
payload.payload["json_attributes_topic"] = f"{base_topic}/attributes"
if self._config.hass_entity_id_prefix:
discovery.object_id = f"{self._config.hass_entity_id_prefix}_{payload_key}"
if data_point.unit:
payload.payload["unit_of_measurement"] = data_point.unit
discovery.unit_of_measurement = data_point.unit

# If we have an entity description, use it:
data_point_key = self._get_data_point_key(payload_key, data_point)
if description := ENTITY_DESCRIPTIONS.get(data_point_key):
for discovery_key, value in (
("device_class", description.device_class),
("entity_category", description.entity_category),
("icon", description.icon),
(
"state_class",
STATE_CLASS_OVERRIDES.get(payload_key, description.state_class),
),
):
if not value:
continue
payload.payload[discovery_key] = value
if description.device_class:
discovery.device_class = description.device_class
if description.entity_category:
discovery.entity_category = description.entity_category
if description.icon:
discovery.icon = description.icon
if description.state_class:
discovery.state_class = STATE_CLASS_OVERRIDES.get(
payload_key, description.state_class
)
else:
LOGGER.debug(
'Missing entity description for "%s" (please report it!)',
payload_key,
"No entity description found for data point %s", data_point_key
)

return payload
return discovery

async def async_publish(self, data: dict[str, CalculatedValueType]) -> None:
"""Publish to MQTT.
Expand All @@ -538,38 +556,54 @@ async def async_publish(self, data: dict[str, CalculatedValueType]) -> None:
processed_data = ProcessedData(self._config, data)
tasks: list[asyncio.Task] = []

try:
for payload_key, data_point in processed_data.output.items():
discovery_payload = self._generate_discovery_payload(
processed_data.device, payload_key, data_point
for payload_key, data_point in processed_data.output.items():
discovery_info = self._get_discovery_info(
processed_data.device, payload_key, data_point
)

if self._discovery_infos.get(discovery_info.unique_id) != discovery_info:
LOGGER.debug(
"Publishing discovery info for %s", discovery_info.unique_id
)
self._discovery_infos[discovery_info.unique_id] = discovery_info
tasks.append(
asyncio.create_task(
self._client.publish(
discovery_info.config_topic,
payload=generate_mqtt_payload(
asdict(
discovery_info,
dict_factory=HassDiscoveryInfo.dict_factory,
)
),
retain=self._config.mqtt_retain,
)
)
)

for topic, payload in (
(discovery_payload.topic, discovery_payload.payload),
(
discovery_payload.payload["availability_topic"],
get_availability_payload(data_point),
),
(
discovery_payload.payload["json_attributes_topic"],
data_point.attributes,
),
(
discovery_payload.payload["state_topic"],
data_point.value,
),
):
tasks.append(
asyncio.create_task(
self._client.publish(
topic,
payload=generate_mqtt_payload(payload),
retain=self._config.mqtt_retain,
)
for topic, payload in (
(
discovery_info.availability_topic,
get_availability_payload(data_point),
),
(discovery_info.json_attributes_topic, data_point.attributes),
(
discovery_info.state_topic,
data_point.value,
),
):
tasks.append(
asyncio.create_task(
self._client.publish(
topic,
payload=generate_mqtt_payload(payload),
retain=self._config.mqtt_retain,
)
)
)

await asyncio.gather(*tasks)
try:
await asyncio.gather(*tasks)
except MqttError:
for task in tasks:
task.cancel()
Expand Down
Loading

0 comments on commit 5b5d145

Please sign in to comment.