diff --git a/CHANGELOG.md b/CHANGELOG.md index 05f21ab..f6ebb78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.14.9] - 2024-06-14 + +### Added +- Support for adding extra headers for RT websocket + ## [1.14.8] - 2024-05-14 ### Changed diff --git a/VERSION b/VERSION index 9be7846..0b94c5f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.14.8 +1.14.9 diff --git a/speechmatics/cli.py b/speechmatics/cli.py index 7bab565..8bd9c24 100755 --- a/speechmatics/cli.py +++ b/speechmatics/cli.py @@ -650,6 +650,7 @@ def rt_main(args): transcription_config = get_transcription_config(args) settings = get_connection_settings(args, lang=transcription_config.language) api = WebsocketClient(settings) + extra_headers = args.get("extra_headers") if settings.url.lower().startswith("ws://") and args["ssl_mode"] != "none": raise SystemExit( @@ -677,7 +678,11 @@ def rt_main(args): def run(stream): try: api.run_synchronously( - stream, transcription_config, get_audio_settings(args), from_cli=True + stream, + transcription_config, + get_audio_settings(args), + from_cli=True, + extra_headers=extra_headers, ) except KeyboardInterrupt: # Gracefully handle Ctrl-C, else we get a huge stack-trace. diff --git a/speechmatics/cli_parser.py b/speechmatics/cli_parser.py index c91c375..7839077 100644 --- a/speechmatics/cli_parser.py +++ b/speechmatics/cli_parser.py @@ -9,6 +9,25 @@ LOGGER = logging.getLogger(__name__) +class kvdictAppendAction(argparse.Action): + """ + argparse action to split an argument into KEY=VALUE form + on the first = and append to a dictionary. + """ + + def __call__(self, parser, args, values, option_string=None): + for pair in values: + try: + (k, v) = pair.split("=", 2) + except ValueError: + raise argparse.ArgumentError( + self, f'could not parse argument "{pair}" as k=v format' + ) + d = getattr(args, self.dest) or {} + d[k] = v + setattr(args, self.dest, d) + + def additional_vocab_item(to_parse): """ Parses a single item of additional vocab. Used in conjunction with the @@ -493,6 +512,15 @@ def get_arg_parser(): required=False, help="Removes words tagged as disfluency.", ) + rt_transcribe_command_parser.add_argument( + "--extra-headers", + default=dict(), + nargs="+", + action=kvdictAppendAction, + metavar="KEY=VALUE", + required=False, + help="Adds extra headers to the websocket client", + ) # Parent parser for batch auto-chapters argument batch_audio_events_parser = argparse.ArgumentParser(add_help=False) diff --git a/speechmatics/client.py b/speechmatics/client.py index 43e3ebe..8aca612 100644 --- a/speechmatics/client.py +++ b/speechmatics/client.py @@ -9,7 +9,7 @@ import json import logging import os -from typing import Union +from typing import Dict, Union from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse import httpx @@ -423,8 +423,9 @@ async def run( self, stream, transcription_config: TranscriptionConfig, - audio_settings=AudioSettings(), - from_cli=False, + audio_settings: AudioSettings = AudioSettings(), + from_cli: bool = False, + extra_headers: Dict = dict(), ): """ Begin a new recognition session. @@ -451,7 +452,6 @@ async def run( self.seq_no = 0 self._language_pack_info = None await self._init_synchronization_primitives() - extra_headers = {} if ( not self.connection_settings.generate_temp_token and self.connection_settings.auth_token is not None diff --git a/tests/test_cli.py b/tests/test_cli.py index ce22044..c77e592 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -60,6 +60,21 @@ ["rt", "transcribe", "--additional-vocab", "Speechmatics", "gnocchi"], {"additional_vocab": ["Speechmatics", "gnocchi"]}, ), + ( + [ + "rt", + "transcribe", + "--extra-headers", + "magic_header=magic_value", + "another_magic_header=another_magic_value", + ], + { + "extra_headers": { + "magic_header": "magic_value", + "another_magic_header": "another_magic_value", + } + }, + ), ( ["batch", "transcribe", "--additional-vocab", "Speechmatics", "gnocchi"], {"additional_vocab": ["Speechmatics", "gnocchi"]}, diff --git a/tests/test_client.py b/tests/test_client.py index 18d1e7d..bcd5f64 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -321,6 +321,33 @@ async def mock_connect(*_, **__): assert exc is not None +def test_extra_headers_are_passed_to_websocket_connect_correctly(mock_server): + """Tests extra headers are passed correclty to the websocket onConnect call.""" + extra_headers = {"keyy": "value"} + + def call_exit(*args, **kwargs): + raise Exception() + + connect_mock = MagicMock(side_effect=call_exit) + ws_client, transcription_config, audio_settings = default_ws_client_setup( + mock_server.url + ) + stream = MagicMock() + with patch("websockets.connect", connect_mock): + try: + ws_client.run_synchronously( + stream, + transcription_config, + audio_settings, + extra_headers=extra_headers, + ) + except Exception: + assert len(connect_mock.mock_calls) == 1 + assert ( + connect_mock.mock_calls[0][2]["extra_headers"] == extra_headers + ), f"Extra headers don't appear in the call list = {connect_mock.mock_calls}" + + @pytest.mark.asyncio async def test__buffer_semaphore(): """Test the WebsocketClient internal BoundedSemaphore."""