Skip to content

Commit

Permalink
Only publish MQTT Discovery payload when config changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bachya committed Oct 29, 2023
1 parent 0694a63 commit e031092
Show file tree
Hide file tree
Showing 2 changed files with 552 additions and 533 deletions.
199 changes: 111 additions & 88 deletions ecowitt2mqtt/helpers/publisher/mqtt/hass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Any, TypedDict
from dataclasses import asdict, dataclass
from typing import TypedDict

from aiomqtt import Client, MqttError

Expand Down Expand Up @@ -152,12 +152,37 @@ class EntityDescription:
state_class: str | None = None


@dataclass(frozen=True)
class HassDiscoveryDevice:
"""Define an MQTT Discovery device."""

identifiers: list[str]
manufacturer: str
model: str
name: str
sw_version: str


@dataclass
class HassDiscoveryPayload:
"""Define a MQTT Discovery configuration for an entity."""
class HassDiscoveryInfo:
"""Define an MQTT Discovery payload."""

payload: dict[str, Any]
topic: str
availability_topic: str
config_topic: str
device: HassDiscoveryDevice
json_attributes_topic: str
name: str
retain: bool
state_topic: str
unique_id: str

device_class: str | None = None
entity_category: str | None = None
icon: str | None = None
object_id: str | None = None
qos: int = 1
state_class: str | None = None
unit_of_measurement: str | None = None


AVAILABILITY_OFFLINE = "offline"
Expand Down Expand Up @@ -443,23 +468,12 @@ def __init__(self, config: Config, client: Client) -> None:
"""
super().__init__(config, client)

self._discovery_payloads: dict[str, HassDiscoveryPayload] = {}

def _generate_discovery_payload( # pylint: disable=too-many-branches
self, device: Device, payload_key: str, data_point: CalculatedDataPoint
) -> HassDiscoveryPayload:
"""Generate a discovery payload for an entity.
Args:
device: A Device object.
payload_key: The Ecowitt payload key.
data_point: A CalculatedDataPoint object.
self._discovery_infos: dict[str, HassDiscoveryInfo] = {}

Returns:
A parsed HassDiscoveryPayload object.
"""
# Since batteries can be one of many different strategies, we calculate an
# entity description at runtime:
def _get_data_point_key(
self, payload_key: str, data_point: CalculatedDataPoint
) -> str:
"""Get the data point key."""
if data_point.data_point_key in (DATA_POINT_GLOB_BATT, DATA_POINT_GLOB_VOLT):
strategy = get_battery_strategy(self._config, payload_key)
if strategy == BatteryStrategy.BOOLEAN:
Expand All @@ -471,60 +485,58 @@ def _generate_discovery_payload( # pylint: disable=too-many-branches
else:
data_point_key = data_point.data_point_key

return data_point_key

def _get_discovery_info(
self, device: Device, payload_key: str, data_point: CalculatedDataPoint
) -> HassDiscoveryInfo:
"""Get the discovery payload from a payload."""
base_topic = (
f"{self._config.hass_discovery_prefix}/{PLATFORM_MAP[data_point.data_type]}"
f"/{device.unique_id}/{payload_key}"
)

config = {
"availability_topic": f"{base_topic}/availability",
"device": {
"identifiers": [device.unique_id],
"manufacturer": device.manufacturer,
"model": device.model,
"name": device.name,
"sw_version": device.station_type,
},
"json_attributes_topic": f"{base_topic}/attributes",
"name": payload_key,
"qos": 1,
"retain": self._config.mqtt_retain,
"state_topic": f"{base_topic}/state",
"unique_id": f"{device.unique_id}_{payload_key}",
}
if self._config.hass_entity_id_prefix:
config["object_id"] = f"{self._config.hass_entity_id_prefix}_{payload_key}"

payload = self._discovery_payloads[payload_key] = HassDiscoveryPayload(
config, f"{base_topic}/config"
discovery = HassDiscoveryInfo(
availability_topic=f"{base_topic}/availability",
config_topic=f"{base_topic}/config",
device=HassDiscoveryDevice(
identifiers=[device.unique_id],
manufacturer=device.manufacturer,
model=device.model,
name=device.name,
sw_version=device.station_type,
),
json_attributes_topic=f"{base_topic}/attributes",
name=payload_key,
retain=self._config.mqtt_retain,
state_topic=f"{base_topic}/state",
unique_id=f"{device.unique_id}_{payload_key}",
)

if data_point.attributes:
payload.payload["json_attributes_topic"] = f"{base_topic}/attributes"
if self._config.hass_entity_id_prefix:
discovery.object_id = f"{self._config.hass_entity_id_prefix}_{payload_key}"
if data_point.unit:
payload.payload["unit_of_measurement"] = data_point.unit
discovery.unit_of_measurement = data_point.unit

# If we have an entity description, use it:
data_point_key = self._get_data_point_key(payload_key, data_point)
if description := ENTITY_DESCRIPTIONS.get(data_point_key):
for discovery_key, value in (
("device_class", description.device_class),
("entity_category", description.entity_category),
("icon", description.icon),
(
"state_class",
STATE_CLASS_OVERRIDES.get(payload_key, description.state_class),
),
):
if not value:
continue
payload.payload[discovery_key] = value
if description.device_class:
discovery.device_class = description.device_class
if description.entity_category:
discovery.entity_category = description.entity_category
if description.icon:
discovery.icon = description.icon
if description.state_class:
discovery.state_class = STATE_CLASS_OVERRIDES.get(
payload_key, description.state_class
)
else:
LOGGER.debug(
'Missing entity description for "%s" (please report it!)',
payload_key,
"No entity description found for data point %s", data_point_key
)

return payload
return discovery

async def async_publish(self, data: dict[str, CalculatedValueType]) -> None:
"""Publish to MQTT.
Expand All @@ -538,38 +550,49 @@ async def async_publish(self, data: dict[str, CalculatedValueType]) -> None:
processed_data = ProcessedData(self._config, data)
tasks: list[asyncio.Task] = []

try:
for payload_key, data_point in processed_data.output.items():
discovery_payload = self._generate_discovery_payload(
processed_data.device, payload_key, data_point
for payload_key, data_point in processed_data.output.items():
discovery_info = self._get_discovery_info(
processed_data.device, payload_key, data_point
)

if self._discovery_infos.get(discovery_info.unique_id) != discovery_info:
LOGGER.debug(
"Publishing discovery info for %s", discovery_info.unique_id
)
self._discovery_infos[discovery_info.unique_id] = discovery_info
tasks.append(
asyncio.create_task(
self._client.publish(
discovery_info.config_topic,
payload=generate_mqtt_payload(asdict(discovery_info)),
retain=self._config.mqtt_retain,
)
)
)

for topic, payload in (
(discovery_payload.topic, discovery_payload.payload),
(
discovery_payload.payload["availability_topic"],
get_availability_payload(data_point),
),
(
discovery_payload.payload["json_attributes_topic"],
data_point.attributes,
),
(
discovery_payload.payload["state_topic"],
data_point.value,
),
):
tasks.append(
asyncio.create_task(
self._client.publish(
topic,
payload=generate_mqtt_payload(payload),
retain=self._config.mqtt_retain,
)
for topic, payload in (
(
discovery_info.availability_topic,
get_availability_payload(data_point),
),
(discovery_info.json_attributes_topic, data_point.attributes),
(
discovery_info.state_topic,
data_point.value,
),
):
tasks.append(
asyncio.create_task(
self._client.publish(
topic,
payload=generate_mqtt_payload(payload),
retain=self._config.mqtt_retain,
)
)
)

await asyncio.gather(*tasks)
try:
await asyncio.gather(*tasks)
except MqttError:
for task in tasks:
task.cancel()
Expand Down
Loading

0 comments on commit e031092

Please sign in to comment.