Skip to content

Commit

Permalink
add extra headers to cli to be able to be used by the Websocket onCon…
Browse files Browse the repository at this point in the history
…nect (#99)

* add cli option for extra headers to be added on the websocket client

---------

Co-authored-by: Georgios Hadjiharalambous <[email protected]>
  • Loading branch information
giorgosHadji and Georgios Hadjiharalambous authored Jun 18, 2024
1 parent 63ba444 commit 1e736f0
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 6 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.14.9] - 2024-06-14

### Added
- Support for adding extra headers for RT websocket

## [1.14.8] - 2024-05-14

### Changed
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.14.8
1.14.9
7 changes: 6 additions & 1 deletion speechmatics/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ def rt_main(args):
transcription_config = get_transcription_config(args)
settings = get_connection_settings(args, lang=transcription_config.language)
api = WebsocketClient(settings)
extra_headers = args.get("extra_headers")

if settings.url.lower().startswith("ws://") and args["ssl_mode"] != "none":
raise SystemExit(
Expand Down Expand Up @@ -677,7 +678,11 @@ def rt_main(args):
def run(stream):
try:
api.run_synchronously(
stream, transcription_config, get_audio_settings(args), from_cli=True
stream,
transcription_config,
get_audio_settings(args),
from_cli=True,
extra_headers=extra_headers,
)
except KeyboardInterrupt:
# Gracefully handle Ctrl-C, else we get a huge stack-trace.
Expand Down
28 changes: 28 additions & 0 deletions speechmatics/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,25 @@
LOGGER = logging.getLogger(__name__)


class kvdictAppendAction(argparse.Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary.
"""

def __call__(self, parser, args, values, option_string=None):
for pair in values:
try:
(k, v) = pair.split("=", 2)
except ValueError:
raise argparse.ArgumentError(
self, f'could not parse argument "{pair}" as k=v format'
)
d = getattr(args, self.dest) or {}
d[k] = v
setattr(args, self.dest, d)


def additional_vocab_item(to_parse):
"""
Parses a single item of additional vocab. Used in conjunction with the
Expand Down Expand Up @@ -493,6 +512,15 @@ def get_arg_parser():
required=False,
help="Removes words tagged as disfluency.",
)
rt_transcribe_command_parser.add_argument(
"--extra-headers",
default=dict(),
nargs="+",
action=kvdictAppendAction,
metavar="KEY=VALUE",
required=False,
help="Adds extra headers to the websocket client",
)

# Parent parser for batch auto-chapters argument
batch_audio_events_parser = argparse.ArgumentParser(add_help=False)
Expand Down
8 changes: 4 additions & 4 deletions speechmatics/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json
import logging
import os
from typing import Union
from typing import Dict, Union
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

import httpx
Expand Down Expand Up @@ -423,8 +423,9 @@ async def run(
self,
stream,
transcription_config: TranscriptionConfig,
audio_settings=AudioSettings(),
from_cli=False,
audio_settings: AudioSettings = AudioSettings(),
from_cli: bool = False,
extra_headers: Dict = dict(),
):
"""
Begin a new recognition session.
Expand All @@ -451,7 +452,6 @@ async def run(
self.seq_no = 0
self._language_pack_info = None
await self._init_synchronization_primitives()
extra_headers = {}
if (
not self.connection_settings.generate_temp_token
and self.connection_settings.auth_token is not None
Expand Down
15 changes: 15 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@
["rt", "transcribe", "--additional-vocab", "Speechmatics", "gnocchi"],
{"additional_vocab": ["Speechmatics", "gnocchi"]},
),
(
[
"rt",
"transcribe",
"--extra-headers",
"magic_header=magic_value",
"another_magic_header=another_magic_value",
],
{
"extra_headers": {
"magic_header": "magic_value",
"another_magic_header": "another_magic_value",
}
},
),
(
["batch", "transcribe", "--additional-vocab", "Speechmatics", "gnocchi"],
{"additional_vocab": ["Speechmatics", "gnocchi"]},
Expand Down
27 changes: 27 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,33 @@ async def mock_connect(*_, **__):
assert exc is not None


def test_extra_headers_are_passed_to_websocket_connect_correctly(mock_server):
"""Tests extra headers are passed correclty to the websocket onConnect call."""
extra_headers = {"keyy": "value"}

def call_exit(*args, **kwargs):
raise Exception()

connect_mock = MagicMock(side_effect=call_exit)
ws_client, transcription_config, audio_settings = default_ws_client_setup(
mock_server.url
)
stream = MagicMock()
with patch("websockets.connect", connect_mock):
try:
ws_client.run_synchronously(
stream,
transcription_config,
audio_settings,
extra_headers=extra_headers,
)
except Exception:
assert len(connect_mock.mock_calls) == 1
assert (
connect_mock.mock_calls[0][2]["extra_headers"] == extra_headers
), f"Extra headers don't appear in the call list = {connect_mock.mock_calls}"


@pytest.mark.asyncio
async def test__buffer_semaphore():
"""Test the WebsocketClient internal BoundedSemaphore."""
Expand Down

0 comments on commit 1e736f0

Please sign in to comment.