diff --git a/custom_components/zha_toolkit/__init__.py b/custom_components/zha_toolkit/__init__.py index eb8209c..b19268e 100644 --- a/custom_components/zha_toolkit/__init__.py +++ b/custom_components/zha_toolkit/__init__.py @@ -1,10 +1,11 @@ import importlib import logging +from typing import Optional import homeassistant.helpers.config_validation as cv import voluptuous as vol -from homeassistant.util import dt as dt_util from homeassistant.components.zha.core.gateway import ZHAGateway +from homeassistant.util import dt as dt_util from zigpy import types as t from . import params as PARDEFS @@ -29,7 +30,7 @@ LOADED_VERSION = "" try: - DEFAULT_OTAU # type:ignore[used-before-def] + DEFAULT_OTAU # type:ignore[used-before-def] except NameError: DEFAULT_OTAU = "/config/zigpy_ota" @@ -669,13 +670,20 @@ async def toolkit_service(service): LOGGER.info("Running ZHA Toolkit service: %s", service) global LOADED_VERSION # pylint: disable=global-variable-not-assigned - try: - zha_gw: ZHAGateway = hass_ref.data["zha"].gateway - except AttributeError: + zha = hass_ref.data["zha"] + zha_gw: Optional[ZHAGateway] = zha.get("gateway", None) + if zha_gw is None: + zha_gw = zha.get("zha_gateway", None) + if zha_gw is None: LOGGER.error( - "Missing hass.data['zha'].gateway - not running %s", + "Missing hass.data['zha']/gateway - not found/running %s - on %r", service, + zha, ) + LOGGER.debug( + "Got hass.data['zha']/gateway %r", + zha_gw, + ) # importlib.reload(PARDEFS) # S = PARDEFS.SERVICES @@ -708,7 +716,7 @@ async def toolkit_service(service): # Decode parameters params = u.extractParams(service) - app = zha_gw.application_controller + app = zha_gw.application_controller # type: ignore ieee = await u.get_ieee(app, zha_gw, ieee_str) @@ -767,7 +775,7 @@ async def toolkit_service(service): handler_result = None try: handler_result = await handler( - zha_gw.application_controller, + zha_gw.application_controller, # type: ignore zha_gw, ieee, cmd, @@ -791,15 +799,15 @@ async def toolkit_service(service): LOGGER.debug( "Fire %s -> %s", params[p.EVT_SUCCESS], event_data ) - zha_gw.hass.bus.fire(params[p.EVT_SUCCESS], event_data) + u.get_hass(zha_gw).bus.fire(params[p.EVT_SUCCESS], event_data) else: if params[p.EVT_FAIL] is not None: LOGGER.debug("Fire %s -> %s", params[p.EVT_FAIL], event_data) - zha_gw.hass.bus.fire(params[p.EVT_FAIL], event_data) + u.get_hass(zha_gw).bus.fire(params[p.EVT_FAIL], event_data) if params[p.EVT_DONE] is not None: LOGGER.debug("Fire %s -> %s", params[p.EVT_DONE], event_data) - zha_gw.hass.bus.fire(params[p.EVT_DONE], event_data) + u.get_hass(zha_gw).bus.fire(params[p.EVT_DONE], event_data) if handler_exception is not None: raise handler_exception @@ -906,4 +914,4 @@ async def _register_services(hass): async def command_handler_register_services( app, listener, ieee, cmd, data, service, params, event_data ): - await _register_services(listener.hass) + await _register_services(u.get_hass(listener)) diff --git a/custom_components/zha_toolkit/ha.py b/custom_components/zha_toolkit/ha.py index 2209bfa..4210059 100644 --- a/custom_components/zha_toolkit/ha.py +++ b/custom_components/zha_toolkit/ha.py @@ -21,7 +21,9 @@ async def ha_set_state( # noqa: C901 state_template_str = params[p.STATE_VALUE_TEMPLATE] if state_template_str is not None: - template = Template("{{ " + state_template_str + " }}", listener.hass) + template = Template( + "{{ " + state_template_str + " }}", u.get_hass(listener) + ) new_value = template.async_render(value=val, attr_val=val) val = new_value @@ -40,7 +42,7 @@ async def ha_set_state( # noqa: C901 val, ) u.set_state( - listener.hass, + u.get_hass(listener), params[p.STATE_ID], val, key=params[p.STATE_ATTR], diff --git a/custom_components/zha_toolkit/neighbours.py b/custom_components/zha_toolkit/neighbours.py index 49e9d3f..6e84a1f 100644 --- a/custom_components/zha_toolkit/neighbours.py +++ b/custom_components/zha_toolkit/neighbours.py @@ -28,7 +28,7 @@ async def get_routes_and_neighbours( ieee_tail = "".join([f"{o:02X}" for o in device.ieee]) fname = os.path.join( - listener.hass.config.config_dir, + u.get_hass(listener).config.config_dir, "scans", f"routes_and_neighbours_{ieee_tail}.json", ) @@ -75,7 +75,7 @@ async def all_routes_and_neighbours( event_data["result"] = all_routes all_routes_name = os.path.join( - listener.hass.config.config_dir, + u.get_hass(listener).config.config_dir, "scans", "all_routes_and_neighbours.json", ) diff --git a/custom_components/zha_toolkit/utils.py b/custom_components/zha_toolkit/utils.py index 6568c9b..c9cea6f 100644 --- a/custom_components/zha_toolkit/utils.py +++ b/custom_components/zha_toolkit/utils.py @@ -9,6 +9,7 @@ import typing from enum import Enum +from homeassistant.components.zha.core.gateway import ZHAGateway from pkg_resources import get_distribution, parse_version from zigpy import types as t from zigpy.exceptions import ControllerException, DeliveryError @@ -274,10 +275,12 @@ async def get_ieee(app, listener, ref): # Todo: check if NWK address entity_registry = ( # Deprecated >= 2022.6.0 - await listener.hass.helpers.entity_registry.async_get_registry() + await get_hass( + listener + ).helpers.entity_registry.async_get_registry() if not is_ha_ge("2022.6") - else listener.hass.helpers.entity_registry.async_get( - listener.hass + else get_hass(listener).helpers.entity_registry.async_get( + get_hass(listener) ) ) # LOGGER.debug("registry %s",entity_registry) @@ -291,10 +294,12 @@ async def get_ieee(app, listener, ref): device_registry = ( # Deprecated >= 2022.6.0 - await listener.hass.helpers.device_registry.async_get_registry() + await get_hass( + listener + ).helpers.device_registry.async_get_registry() if not is_ha_ge("2022.6") - else listener.hass.helpers.device_registry.async_get( - listener.hass + else get_hass(listener).helpers.device_registry.async_get( + get_hass(listener) ) ) registry_device = device_registry.async_get(registry_entity.device_id) @@ -445,7 +450,7 @@ def write_json_to_file( if listener is None or subdir == "local": base_dir = os.path.dirname(__file__) else: - base_dir = listener.hass.config.config_dir + base_dir = get_hass(listener).config.config_dir out_dir = os.path.join(base_dir, subdir) if not os.path.isdir(out_dir): @@ -477,7 +482,7 @@ def append_to_csvfile( if listener is None or subdir == "local": base_dir = os.path.dirname(__file__) else: - base_dir = listener.hass.config.config_dir + base_dir = get_hass(listener).config.config_dir out_dir = os.path.join(base_dir, subdir) if not os.path.isdir(out_dir): @@ -970,3 +975,11 @@ def is_zigpy_ge(version: str) -> bool: def is_ha_ge(version: str) -> bool: """Test if zigpy library is newer than version""" return parse_version(getHaVersion()) >= parse_version(version) + + +def get_hass(gateway: ZHAGateway): + """HA Version independent way of getting hass from gateway""" + hass = getattr(gateway, "_hass", None) + if hass is None: + hass = getattr(gateway, "hass", None) + return hass diff --git a/custom_components/zha_toolkit/zcl_attr.py b/custom_components/zha_toolkit/zcl_attr.py index 83892b8..64621ce 100644 --- a/custom_components/zha_toolkit/zcl_attr.py +++ b/custom_components/zha_toolkit/zcl_attr.py @@ -532,7 +532,8 @@ async def attr_write( # noqa: C901 ) template = Template( - "{{ " + state_template_str + " }}", listener.hass + "{{ " + state_template_str + " }}", + u.get_hass(listener), ) try: val = template.async_render(value=val, attr_val=val) @@ -561,7 +562,7 @@ async def attr_write( # noqa: C901 attr_id, ) u.set_state( - listener.hass, + u.get_hass(listener), params[p.STATE_ID], val, key=params[p.STATE_ATTR],