Skip to content

Commit

Permalink
Custom hosts (#70)
Browse files Browse the repository at this point in the history
* Support custom deepgram hostnames.

This is useful to support self-hosted deployments or deepgram in-house
development of the SDK itself.

* Add options for constructing Deepgram with different hosts

* Redact api keys
  • Loading branch information
jcdyer authored Jul 8, 2024
1 parent d60f05b commit 641b751
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 38 deletions.
109 changes: 100 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,37 @@

use std::io;

use redacted::RedactedString;
use reqwest::{
header::{HeaderMap, HeaderValue},
RequestBuilder,
};
use serde::de::DeserializeOwned;
use thiserror::Error;
use url::Url;

pub mod billing;
pub mod invitations;
pub mod keys;
pub mod members;
pub mod projects;
mod redacted;
mod response;
pub mod scopes;
pub mod transcription;
pub mod usage;

mod response;
static DEEPGRAM_BASE_URL: &str = "https://api.deepgram.com";

/// A client for the Deepgram API.
///
/// Make transcriptions requests using [`Deepgram::transcription`].
#[derive(Debug, Clone)]
pub struct Deepgram {
#[cfg_attr(not(feature = "live"), allow(unused))]
api_key: String,
api_key: Option<RedactedString>,
#[cfg_attr(not(any(feature = "live", feature = "prerecorded")), allow(unused))]
base_url: Url,
client: reqwest::Client,
}

Expand Down Expand Up @@ -81,6 +87,8 @@ type Result<T> = std::result::Result<T, DeepgramError>;
impl Deepgram {
/// Construct a new Deepgram client.
///
/// The client will be pointed at Deepgram's hosted API.
///
/// Create your first API key on the [Deepgram Console][console].
///
/// [console]: https://console.deepgram.com/
Expand All @@ -89,6 +97,88 @@ impl Deepgram {
///
/// Panics under the same conditions as [`reqwest::Client::new`].
pub fn new<K: AsRef<str>>(api_key: K) -> Self {
let api_key = Some(api_key.as_ref().to_owned());
Self::inner_constructor(DEEPGRAM_BASE_URL.try_into().unwrap(), api_key)
}

/// Construct a new Deepgram client with the specified base URL.
///
/// When using a self-hosted instance of deepgram, this will be the
/// host portion of your own instance. For instance, if you would
/// query your deepgram instance at `http://deepgram.internal/v1/listen`,
/// the base_url will be `http://deepgram.internal`.
///
/// Admin features, such as billing, usage, and key management will
/// still go through the hosted site at `https://api.deepgram.com`.
///
/// Self-hosted instances do not in general authenticate incoming
/// requests, so unlike in [`Deepgram::new`], so no api key needs to be
/// provided. The SDK will not include an `Authorization` header in its
/// requests. If an API key is required, consider using
/// [`Deepgram::with_base_url_and_api_key`].
///
/// [console]: https://console.deepgram.com/
///
/// # Example:
///
/// ```
/// # use deepgram::Deepgram;
/// let deepgram = Deepgram::with_base_url(
/// "http://localhost:8080",
/// );
/// ```
///
/// # Panics
///
/// Panics under the same conditions as [`reqwest::Client::new`], or if `base_url`
/// is not a valid URL.
pub fn with_base_url<U>(base_url: U) -> Self
where
U: TryInto<Url>,
U::Error: std::fmt::Debug,
{
let base_url = base_url.try_into().expect("base_url must be a valid Url");
Self::inner_constructor(base_url, None)
}

/// Construct a new Deepgram client with the specified base URL and
/// API Key.
///
/// When using a self-hosted instance of deepgram, this will be the
/// host portion of your own instance. For instance, if you would
/// query your deepgram instance at `http://deepgram.internal/v1/listen`,
/// the base_url will be `http://deepgram.internal`.
///
/// Admin features, such as billing, usage, and key management will
/// still go through the hosted site at `https://api.deepgram.com`.
///
/// [console]: https://console.deepgram.com/
///
/// # Example:
///
/// ```
/// # use deepgram::Deepgram;
/// let deepgram = Deepgram::with_base_url_and_api_key(
/// "http://localhost:8080",
/// "apikey12345",
/// );
/// ```
///
/// # Panics
///
/// Panics under the same conditions as [`reqwest::Client::new`], or if `base_url`
/// is not a valid URL.
pub fn with_base_url_and_api_key<U, K>(base_url: U, api_key: K) -> Self
where
U: TryInto<Url>,
U::Error: std::fmt::Debug,
K: AsRef<str>,
{
let base_url = base_url.try_into().expect("base_url must be a valid Url");
Self::inner_constructor(base_url, Some(api_key.as_ref().to_owned()))
}

fn inner_constructor(base_url: Url, api_key: Option<String>) -> Self {
static USER_AGENT: &str = concat!(
env!("CARGO_PKG_NAME"),
"/",
Expand All @@ -98,17 +188,18 @@ impl Deepgram {

let authorization_header = {
let mut header = HeaderMap::new();
header.insert(
"Authorization",
HeaderValue::from_str(&format!("Token {}", api_key.as_ref()))
.expect("Invalid API key"),
);
if let Some(api_key) = &api_key {
header.insert(
"Authorization",
HeaderValue::from_str(&format!("Token {}", api_key)).expect("Invalid API key"),
);
}
header
};
let api_key = api_key.as_ref().to_owned();

Deepgram {
api_key,
api_key: api_key.map(RedactedString),
base_url,
client: reqwest::Client::builder()
.user_agent(USER_AGENT)
.default_headers(authorization_header)
Expand Down
18 changes: 18 additions & 0 deletions src/redacted.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use std::{fmt, ops::Deref};

#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct RedactedString(pub String);

impl fmt::Debug for RedactedString {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("***")
}
}

impl Deref for RedactedString {
type Target = str;

fn deref(&self) -> &Self::Target {
&self.0
}
}
90 changes: 63 additions & 27 deletions src/transcription/live.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ use crate::{Deepgram, DeepgramError, Result};

use super::Transcription;

static LIVE_LISTEN_URL_PATH: &str = "v1/listen";

#[derive(Debug)]
pub struct StreamRequestBuilder<'a, S, E>
where
Expand All @@ -40,6 +42,7 @@ where
encoding: Option<String>,
sample_rate: Option<u32>,
channels: Option<u16>,
stream_url: Url,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -95,7 +98,18 @@ impl Transcription<'_> {
encoding: None,
sample_rate: None,
channels: None,
stream_url: self.listen_stream_url(),
}
}

fn listen_stream_url(&self) -> Url {
let mut url = self.0.base_url.join(LIVE_LISTEN_URL_PATH).unwrap();
match url.scheme() {
"http" | "ws" => url.set_scheme("ws").unwrap(),
"https" | "wss" => url.set_scheme("wss").unwrap(),
_ => panic!("base_url must have a scheme of http, https, ws, or wss"),
}
url
}
}

Expand Down Expand Up @@ -201,42 +215,43 @@ where
E: Send + std::fmt::Debug,
{
pub async fn start(self) -> Result<Receiver<Result<StreamResponse>>> {
let StreamRequestBuilder {
config,
source,
encoding,
sample_rate,
channels,
} = self;
let mut source = source
.ok_or(DeepgramError::NoSource)?
.map(|res| res.map(|bytes| Message::binary(Vec::from(bytes.as_ref()))));

// This unwrap is safe because we're parsing a static.
let mut base = Url::parse("wss://api.deepgram.com/v1/listen").unwrap();
let mut url = self.stream_url;
{
let mut pairs = base.query_pairs_mut();
if let Some(encoding) = encoding {
pairs.append_pair("encoding", &encoding);
let mut pairs = url.query_pairs_mut();
if let Some(encoding) = &self.encoding {
pairs.append_pair("encoding", encoding);
}
if let Some(sample_rate) = sample_rate {
if let Some(sample_rate) = self.sample_rate {
pairs.append_pair("sample_rate", &sample_rate.to_string());
}
if let Some(channels) = channels {
if let Some(channels) = self.channels {
pairs.append_pair("channels", &channels.to_string());
}
}

let request = Request::builder()
.method("GET")
.uri(base.to_string())
.header("authorization", format!("token {}", config.api_key))
.header("sec-websocket-key", client::generate_key())
.header("host", "api.deepgram.com")
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.body(())?;
let mut source = self
.source
.ok_or(DeepgramError::NoSource)?
.map(|res| res.map(|bytes| Message::binary(Vec::from(bytes.as_ref()))));

let request = {
let builder = Request::builder()
.method("GET")
.uri(url.to_string())
.header("sec-websocket-key", client::generate_key())
.header("host", "api.deepgram.com")
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13");

let builder = if let Some(api_key) = self.config.api_key.as_deref() {
builder.header("authorization", format!("token {}", api_key))
} else {
builder
};
builder.body(())?
};
let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?;
let (mut write, mut read) = ws_stream.split();
let (mut tx, rx) = mpsc::channel::<Result<StreamResponse>>(1);
Expand Down Expand Up @@ -288,3 +303,24 @@ where
Ok(rx)
}
}

#[cfg(test)]
mod tests {
#[test]
fn test_stream_url() {
let dg = crate::Deepgram::new("token");
assert_eq!(
dg.transcription().listen_stream_url().to_string(),
"wss://api.deepgram.com/v1/listen",
);
}

#[test]
fn test_stream_url_custom_host() {
let dg = crate::Deepgram::with_base_url_and_api_key("http://localhost:8080", "token");
assert_eq!(
dg.transcription().listen_stream_url().to_string(),
"ws://localhost:8080/v1/listen",
);
}
}
32 changes: 30 additions & 2 deletions src/transcription/prerecorded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! [api]: https://developers.deepgram.com/api-reference/#transcription-prerecorded

use reqwest::RequestBuilder;
use url::Url;

use super::Transcription;
use crate::send_and_translate_response;
Expand All @@ -17,7 +18,7 @@ use audio_source::AudioSource;
use options::{Options, SerializableOptions};
use response::{CallbackResponse, Response};

static DEEPGRAM_API_URL_LISTEN: &str = "https://api.deepgram.com/v1/listen";
static DEEPGRAM_API_URL_LISTEN: &str = "v1/listen";

impl Transcription<'_> {
/// Sends a request to Deepgram to transcribe pre-recorded audio.
Expand Down Expand Up @@ -195,7 +196,7 @@ impl Transcription<'_> {
let request_builder = self
.0
.client
.post(DEEPGRAM_API_URL_LISTEN)
.post(self.listen_url())
.query(&SerializableOptions(options));

source.fill_body(request_builder)
Expand Down Expand Up @@ -267,4 +268,31 @@ impl Transcription<'_> {
self.make_prerecorded_request_builder(source, options)
.query(&[("callback", callback)])
}

fn listen_url(&self) -> Url {
self.0.base_url.join(DEEPGRAM_API_URL_LISTEN).unwrap()
}
}

#[cfg(test)]
mod tests {
use crate::Deepgram;

#[test]
fn listen_url() {
let dg = Deepgram::new("token");
assert_eq!(
&dg.transcription().listen_url().to_string(),
"https://api.deepgram.com/v1/listen"
);
}

#[test]
fn listen_url_custom_host() {
let dg = Deepgram::with_base_url("http://localhost:8888/abc/");
assert_eq!(
&dg.transcription().listen_url().to_string(),
"http://localhost:8888/abc/v1/listen"
);
}
}

0 comments on commit 641b751

Please sign in to comment.