From 8c6172210ef96df619ef9b75b52f430754169ef6 Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 20 Apr 2024 03:14:58 +0100 Subject: [PATCH] feat/transformers closes https://github.com/JarbasHiveMind/HiveMind-core/issues/82 --- readme.md => README.md | 0 hivemind_core/protocol.py | 64 +++++++++++++----- hivemind_core/transformers.py | 122 ++++++++++++++++++++++++++++++++++ requirements.txt | 3 +- 4 files changed, 170 insertions(+), 19 deletions(-) rename readme.md => README.md (100%) create mode 100644 hivemind_core/transformers.py diff --git a/readme.md b/README.md similarity index 100% rename from readme.md rename to README.md diff --git a/hivemind_core/protocol.py b/hivemind_core/protocol.py index aa56b92..2d023c7 100644 --- a/hivemind_core/protocol.py +++ b/hivemind_core/protocol.py @@ -3,14 +3,6 @@ from enum import Enum, IntEnum from typing import List, Dict, Optional -from ovos_bus_client import MessageBusClient -from ovos_bus_client.message import Message -from ovos_bus_client.session import Session -from ovos_utils.log import LOG -from poorman_handshake import HandShake, PasswordHandShake -from tornado import ioloop -from tornado.websocket import WebSocketHandler - from hivemind_bus_client.message import HiveMessage, HiveMessageType from hivemind_bus_client.serialization import decode_bitstring, get_bitstring from hivemind_bus_client.util import ( @@ -19,6 +11,17 @@ decrypt_from_json, encrypt_as_json, ) +from ovos_bus_client import MessageBusClient +from ovos_bus_client.message import Message +from ovos_bus_client.session import Session +from ovos_bus_client.util import get_message_lang +from ovos_config import Configuration +from ovos_utils.log import LOG +from poorman_handshake import HandShake, PasswordHandShake +from tornado import ioloop +from tornado.websocket import WebSocketHandler + +from hivemind_core.transformers import MetadataTransformersService, UtteranceTransformersService class ProtocolVersion(IntEnum): @@ -253,11 +256,18 @@ class HiveMindListenerProtocol: mycroft_bus_callback = None # slave asked to inject payload into mycroft bus shared_bus_callback = None # passive sharing of slave device bus (info) + utterance_plugins: UtteranceTransformersService = None + metadata_plugins: MetadataTransformersService = None + def bind(self, websocket, bus): websocket.protocol = self self.internal_protocol = HiveMindListenerInternalProtocol(bus) self.internal_protocol.register_bus_handlers() + config = Configuration().get("hivemind", {}) + self.utterance_plugins = UtteranceTransformersService(bus, config=config) + self.metadata_plugins = MetadataTransformersService(bus, config=config) + def get_bus(self, client: HiveMindClientConnection): # allow subclasses to use dedicated bus per client return self.internal_protocol.bus @@ -303,9 +313,9 @@ def handle_new_client(self, client: HiveMindClientConnection): "max_protocol_version": max_version, "binarize": True, # report we support the binarization scheme "preshared_key": client.crypto_key - is not None, # do we have a pre-shared key (V0 proto) + is not None, # do we have a pre-shared key (V0 proto) "password": client.pswd_handshake - is not None, # is password available (V1 proto, replaces pre-shared key) + is not None, # is password available (V1 proto, replaces pre-shared key) "crypto_required": self.require_crypto, # do we allow unencrypted payloads } msg = HiveMessage(HiveMessageType.HANDSHAKE, payload) @@ -381,7 +391,7 @@ def handle_message(self, message: HiveMessage, client: HiveMindClientConnection) # HiveMind protocol messages - from slave -> master def handle_unknown_message( - self, message: HiveMessage, client: HiveMindClientConnection + self, message: HiveMessage, client: HiveMindClientConnection ): """message handler for non default message types, subclasses can handle their own types here @@ -390,13 +400,13 @@ def handle_unknown_message( """ def handle_binary_message( - self, message: HiveMessage, client: HiveMindClientConnection + self, message: HiveMessage, client: HiveMindClientConnection ): assert message.msg_type == HiveMessageType.BINARY # TODO def handle_handshake_message( - self, message: HiveMessage, client: HiveMindClientConnection + self, message: HiveMessage, client: HiveMindClientConnection ): LOG.debug("handshake received, generating session key") payload = message.payload @@ -450,15 +460,33 @@ def handle_handshake_message( msg = HiveMessage(HiveMessageType.HANDSHAKE, payload) client.send(msg) # client can recreate crypto_key on his side now + def _handle_transformers(self, message: Message) -> Message: + """ + Pipe utterance through transformer plugins to get more metadata. + Utterances may be modified by any parser and context overwritten + """ + lang = get_message_lang(message) # per query lang or default Configuration lang + original = utterances = message.data.get('utterances', []) + message.context["lang"] = lang + utterances, message.context = self.utterance_plugins.transform(utterances, message.context) + if original != utterances: + message.data["utterances"] = utterances + LOG.debug(f"utterances transformed: {original} -> {utterances}") + message.context = self.metadata_plugins.transform(message.context) + return message + def handle_bus_message( - self, message: HiveMessage, client: HiveMindClientConnection + self, message: HiveMessage, client: HiveMindClientConnection ): + if message.payload.msg_type == "recognizer_loop:utterance": + message._payload = self._handle_transformers(message.payload).serialize() + self.handle_inject_mycroft_msg(message.payload, client) if self.mycroft_bus_callback: self.mycroft_bus_callback(message.payload) def handle_broadcast_message( - self, message: HiveMessage, client: HiveMindClientConnection + self, message: HiveMessage, client: HiveMindClientConnection ): """ message (HiveMessage): HiveMind message object @@ -492,7 +520,7 @@ def _unpack_message(self, message: HiveMessage, client: HiveMindClientConnection return pload def handle_propagate_message( - self, message: HiveMessage, client: HiveMindClientConnection + self, message: HiveMessage, client: HiveMindClientConnection ): """ message (HiveMessage): HiveMind message object @@ -533,7 +561,7 @@ def handle_propagate_message( bus.emit(message) def handle_escalate_message( - self, message: HiveMessage, client: HiveMindClientConnection + self, message: HiveMessage, client: HiveMindClientConnection ): """ message (HiveMessage): HiveMind message object @@ -578,7 +606,7 @@ def update_slave_session(self, message: Message, client: HiveMindClientConnectio return message def handle_inject_mycroft_msg( - self, message: Message, client: HiveMindClientConnection + self, message: Message, client: HiveMindClientConnection ): """ message (Message): mycroft bus message object diff --git a/hivemind_core/transformers.py b/hivemind_core/transformers.py new file mode 100644 index 0000000..76e8e0f --- /dev/null +++ b/hivemind_core/transformers.py @@ -0,0 +1,122 @@ +from typing import Optional, List + +from ovos_plugin_manager.metadata_transformers import find_metadata_transformer_plugins +from ovos_plugin_manager.text_transformers import find_utterance_transformer_plugins + +from ovos_utils.json_helper import merge_dict +from ovos_utils.log import LOG + + +class UtteranceTransformersService: + + def __init__(self, bus, config=None): + self.config_core = config or {} + self.loaded_plugins = {} + self.has_loaded = False + self.bus = bus + self.config = self.config_core.get("utterance_transformers") or {} + self.load_plugins() + + def load_plugins(self): + for plug_name, plug in find_utterance_transformer_plugins().items(): + if plug_name in self.config: + # if disabled skip it + if not self.config[plug_name].get("active", True): + continue + try: + self.loaded_plugins[plug_name] = plug() + LOG.info(f"loaded utterance transformer plugin: {plug_name}") + except Exception as e: + LOG.error(e) + LOG.exception(f"Failed to load utterance transformer plugin: {plug_name}") + + @property + def plugins(self): + """ + Return loaded transformers in priority order, such that modules with a + higher `priority` rank are called first and changes from lower ranked + transformers are applied last + + A plugin of `priority` 1 will override any existing context keys and + will be the last to modify utterances` + """ + return sorted(self.loaded_plugins.values(), + key=lambda k: k.priority, reverse=True) + + def shutdown(self): + for module in self.plugins: + try: + module.shutdown() + except: + pass + + def transform(self, utterances: List[str], context: Optional[dict] = None): + context = context or {} + + for module in self.plugins: + try: + utterances, data = module.transform(utterances, context) + _safe = {k:v for k,v in data.items() if k != "session"} # no leaking TTS/STT creds in logs + LOG.debug(f"{module.name}: {_safe}") + context = merge_dict(context, data) + except Exception as e: + LOG.warning(f"{module.name} transform exception: {e}") + return utterances, context + + +class MetadataTransformersService: + + def __init__(self, bus, config=None): + self.config_core = config or {} + self.loaded_plugins = {} + self.has_loaded = False + self.bus = bus + self.config = self.config_core.get("metadata_transformers") or {} + self.load_plugins() + + def load_plugins(self): + for plug_name, plug in find_metadata_transformer_plugins().items(): + if plug_name in self.config: + # if disabled skip it + if not self.config[plug_name].get("active", True): + continue + try: + self.loaded_plugins[plug_name] = plug() + LOG.info(f"loaded metadata transformer plugin: {plug_name}") + except Exception as e: + LOG.error(e) + LOG.exception(f"Failed to load metadata transformer plugin: {plug_name}") + + @property + def plugins(self): + """ + Return loaded transformers in priority order, such that modules with a + higher `priority` rank are called first and changes from lower ranked + transformers are applied last. + + A plugin of `priority` 1 will override any existing context keys + """ + return sorted(self.loaded_plugins.values(), + key=lambda k: k.priority, reverse=True) + + def shutdown(self): + for module in self.plugins: + try: + module.shutdown() + except: + pass + + def transform(self, context: Optional[dict] = None): + context = context or {} + + for module in self.plugins: + try: + data = module.transform(context) + _safe = {k:v for k,v in data.items() if k != "session"} # no leaking TTS/STT creds in logs + LOG.debug(f"{module.name}: {_safe}") + context = merge_dict(context, data) + except Exception as e: + LOG.warning(f"{module.name} transform exception: {e}") + return context + + diff --git a/requirements.txt b/requirements.txt index fb116e4..063e39a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,8 @@ tornado ovos_utils>=0.0.33 pycryptodomex HiveMind_presence>=0.0.2a3 -ovos-bus-client>=0.0.6a5 +ovos-bus-client>=0.0.6 +ovos-plugin-manager poorman_handshake>=0.1.0 click click_default_group