Skip to content

Commit

Permalink
Add backward compatibility while allowing HA2023.10
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeweerd committed Oct 5, 2023
1 parent d5b4b4e commit bd241c6
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 26 deletions.
32 changes: 20 additions & 12 deletions custom_components/zha_toolkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import importlib
import logging
from typing import Optional

import homeassistant.helpers.config_validation as cv
import voluptuous as vol
from homeassistant.util import dt as dt_util
from homeassistant.components.zha.core.gateway import ZHAGateway
from homeassistant.util import dt as dt_util
from zigpy import types as t

from . import params as PARDEFS
Expand All @@ -29,7 +30,7 @@
LOADED_VERSION = ""

try:
DEFAULT_OTAU # type:ignore[used-before-def]
DEFAULT_OTAU # type:ignore[used-before-def]
except NameError:
DEFAULT_OTAU = "/config/zigpy_ota"

Expand Down Expand Up @@ -669,13 +670,20 @@ async def toolkit_service(service):
LOGGER.info("Running ZHA Toolkit service: %s", service)
global LOADED_VERSION # pylint: disable=global-variable-not-assigned

try:
zha_gw: ZHAGateway = hass_ref.data["zha"].gateway
except AttributeError:
zha = hass_ref.data["zha"]
zha_gw: Optional[ZHAGateway] = zha.get("gateway", None)
if zha_gw is None:
zha_gw = zha.get("zha_gateway", None)
if zha_gw is None:
LOGGER.error(
"Missing hass.data['zha'].gateway - not running %s",
"Missing hass.data['zha']/gateway - not found/running %s - on %r",
service,
zha,
)
LOGGER.debug(
"Got hass.data['zha']/gateway %r",
zha_gw,
)

# importlib.reload(PARDEFS)
# S = PARDEFS.SERVICES
Expand Down Expand Up @@ -708,7 +716,7 @@ async def toolkit_service(service):
# Decode parameters
params = u.extractParams(service)

app = zha_gw.application_controller
app = zha_gw.application_controller # type: ignore

ieee = await u.get_ieee(app, zha_gw, ieee_str)

Expand Down Expand Up @@ -767,7 +775,7 @@ async def toolkit_service(service):
handler_result = None
try:
handler_result = await handler(
zha_gw.application_controller,
zha_gw.application_controller, # type: ignore
zha_gw,
ieee,
cmd,
Expand All @@ -791,15 +799,15 @@ async def toolkit_service(service):
LOGGER.debug(
"Fire %s -> %s", params[p.EVT_SUCCESS], event_data
)
zha_gw.hass.bus.fire(params[p.EVT_SUCCESS], event_data)
u.get_hass(zha_gw).bus.fire(params[p.EVT_SUCCESS], event_data)
else:
if params[p.EVT_FAIL] is not None:
LOGGER.debug("Fire %s -> %s", params[p.EVT_FAIL], event_data)
zha_gw.hass.bus.fire(params[p.EVT_FAIL], event_data)
u.get_hass(zha_gw).bus.fire(params[p.EVT_FAIL], event_data)

if params[p.EVT_DONE] is not None:
LOGGER.debug("Fire %s -> %s", params[p.EVT_DONE], event_data)
zha_gw.hass.bus.fire(params[p.EVT_DONE], event_data)
u.get_hass(zha_gw).bus.fire(params[p.EVT_DONE], event_data)

if handler_exception is not None:
raise handler_exception
Expand Down Expand Up @@ -906,4 +914,4 @@ async def _register_services(hass):
async def command_handler_register_services(
app, listener, ieee, cmd, data, service, params, event_data
):
await _register_services(listener.hass)
await _register_services(u.get_hass(listener))
6 changes: 4 additions & 2 deletions custom_components/zha_toolkit/ha.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ async def ha_set_state( # noqa: C901

state_template_str = params[p.STATE_VALUE_TEMPLATE]
if state_template_str is not None:
template = Template("{{ " + state_template_str + " }}", listener.hass)
template = Template(
"{{ " + state_template_str + " }}", u.get_hass(listener)
)
new_value = template.async_render(value=val, attr_val=val)
val = new_value

Expand All @@ -40,7 +42,7 @@ async def ha_set_state( # noqa: C901
val,
)
u.set_state(
listener.hass,
u.get_hass(listener),
params[p.STATE_ID],
val,
key=params[p.STATE_ATTR],
Expand Down
4 changes: 2 additions & 2 deletions custom_components/zha_toolkit/neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def get_routes_and_neighbours(
ieee_tail = "".join([f"{o:02X}" for o in device.ieee])

fname = os.path.join(
listener.hass.config.config_dir,
u.get_hass(listener).config.config_dir,
"scans",
f"routes_and_neighbours_{ieee_tail}.json",
)
Expand Down Expand Up @@ -75,7 +75,7 @@ async def all_routes_and_neighbours(
event_data["result"] = all_routes

all_routes_name = os.path.join(
listener.hass.config.config_dir,
u.get_hass(listener).config.config_dir,
"scans",
"all_routes_and_neighbours.json",
)
Expand Down
29 changes: 21 additions & 8 deletions custom_components/zha_toolkit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import typing
from enum import Enum

from homeassistant.components.zha.core.gateway import ZHAGateway
from pkg_resources import get_distribution, parse_version
from zigpy import types as t
from zigpy.exceptions import ControllerException, DeliveryError
Expand Down Expand Up @@ -274,10 +275,12 @@ async def get_ieee(app, listener, ref):
# Todo: check if NWK address
entity_registry = (
# Deprecated >= 2022.6.0
await listener.hass.helpers.entity_registry.async_get_registry()
await get_hass(
listener
).helpers.entity_registry.async_get_registry()
if not is_ha_ge("2022.6")
else listener.hass.helpers.entity_registry.async_get(
listener.hass
else get_hass(listener).helpers.entity_registry.async_get(
get_hass(listener)
)
)
# LOGGER.debug("registry %s",entity_registry)
Expand All @@ -291,10 +294,12 @@ async def get_ieee(app, listener, ref):

device_registry = (
# Deprecated >= 2022.6.0
await listener.hass.helpers.device_registry.async_get_registry()
await get_hass(
listener
).helpers.device_registry.async_get_registry()
if not is_ha_ge("2022.6")
else listener.hass.helpers.device_registry.async_get(
listener.hass
else get_hass(listener).helpers.device_registry.async_get(
get_hass(listener)
)
)
registry_device = device_registry.async_get(registry_entity.device_id)
Expand Down Expand Up @@ -445,7 +450,7 @@ def write_json_to_file(
if listener is None or subdir == "local":
base_dir = os.path.dirname(__file__)
else:
base_dir = listener.hass.config.config_dir
base_dir = get_hass(listener).config.config_dir

out_dir = os.path.join(base_dir, subdir)
if not os.path.isdir(out_dir):
Expand Down Expand Up @@ -477,7 +482,7 @@ def append_to_csvfile(
if listener is None or subdir == "local":
base_dir = os.path.dirname(__file__)
else:
base_dir = listener.hass.config.config_dir
base_dir = get_hass(listener).config.config_dir

out_dir = os.path.join(base_dir, subdir)
if not os.path.isdir(out_dir):
Expand Down Expand Up @@ -970,3 +975,11 @@ def is_zigpy_ge(version: str) -> bool:
def is_ha_ge(version: str) -> bool:
"""Test if zigpy library is newer than version"""
return parse_version(getHaVersion()) >= parse_version(version)


def get_hass(gateway: ZHAGateway):
"""HA Version independent way of getting hass from gateway"""
hass = getattr(gateway, "_hass", None)
if hass is None:
hass = getattr(gateway, "hass", None)
return hass
5 changes: 3 additions & 2 deletions custom_components/zha_toolkit/zcl_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ async def attr_write( # noqa: C901
)

template = Template(
"{{ " + state_template_str + " }}", listener.hass
"{{ " + state_template_str + " }}",
u.get_hass(listener),
)
try:
val = template.async_render(value=val, attr_val=val)
Expand Down Expand Up @@ -561,7 +562,7 @@ async def attr_write( # noqa: C901
attr_id,
)
u.set_state(
listener.hass,
u.get_hass(listener),
params[p.STATE_ID],
val,
key=params[p.STATE_ATTR],
Expand Down

0 comments on commit bd241c6

Please sign in to comment.