Skip to content
This repository has been archived by the owner on Sep 21, 2024. It is now read-only.

Commit

Permalink
feat: provide custom host injection in client
Browse files Browse the repository at this point in the history
  • Loading branch information
jsantell committed Nov 28, 2023
1 parent 60e9a3a commit 0def4d6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 73 deletions.
94 changes: 21 additions & 73 deletions rust/noosphere-core/src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ where

client: reqwest::Client,

#[cfg(feature = "test-gateway")]
forced_host_header: reqwest::header::HeaderValue,
host_header: Option<reqwest::header::HeaderValue>,
}

impl<K, S> Client<K, S>
Expand All @@ -88,15 +87,18 @@ where
author: &Author<K>,
did_parser: &mut DidParser,
store: S,
override_host: Option<String>,
) -> Result<Client<K, S>> {
debug!("Initializing Noosphere API client");
debug!("Client represents sphere {}", sphere_identity);
debug!("Client targetting API at {}", api_base);

let client = reqwest::Client::new();

#[cfg(feature = "test-gateway")]
let forced_host_header = create_test_header(api_base, &Did::from(sphere_identity))?;
let override_host_header = match override_host {
Some(host) => Some(reqwest::header::HeaderValue::from_str(&host)?),
None => None,
};

let did_response = {
let mut url = api_base.clone();
Expand All @@ -105,9 +107,8 @@ where
#[allow(unused_mut)]
let mut client = client.get(url);

#[cfg(feature = "test-gateway")]
{
client = client.header(reqwest::header::HOST, &forced_host_header);
if let Some(host_header) = override_host_header.as_ref() {
client = client.header(reqwest::header::HOST, host_header);
}

client.send().await?
Expand All @@ -132,8 +133,9 @@ where
)
.await?;

#[cfg(feature = "test-gateway")]
apply_test_header(&mut headers, &forced_host_header);
if let Some(host_header) = override_host_header.as_ref() {
headers.insert(reqwest::header::HOST, host_header.to_owned());
}

let identify_response: v0alpha1::IdentifyResponse = client
.get(url)
Expand All @@ -158,8 +160,7 @@ where
author: author.clone(),
store,
client,
#[cfg(feature = "test-gateway")]
forced_host_header,
host_header: override_host_header,
})
}

Expand Down Expand Up @@ -274,8 +275,9 @@ where
)
.await?;

#[cfg(feature = "test-gateway")]
apply_test_header(&mut headers, &self.forced_host_header);
if let Some(host_header) = self.host_header.as_ref() {
headers.insert(reqwest::header::HOST, host_header.to_owned());
}

let response = self
.client
Expand Down Expand Up @@ -339,8 +341,9 @@ where
)
.await?;

#[cfg(feature = "test-gateway")]
apply_test_header(&mut headers, &self.forced_host_header);
if let Some(host_header) = self.host_header.as_ref() {
headers.insert(reqwest::header::HOST, host_header.to_owned());
}

let response = self
.client
Expand Down Expand Up @@ -543,8 +546,9 @@ where
)
.await?;

#[cfg(feature = "test-gateway")]
apply_test_header(&mut headers, &self.forced_host_header);
if let Some(host_header) = self.host_header.as_ref() {
headers.insert(reqwest::header::HOST, host_header.to_owned());
}

let block_stream = self
.make_push_request(url, headers, &token, push_body)
Expand All @@ -567,59 +571,3 @@ where
Ok(push_response)
}
}

#[cfg(feature = "test-gateway")]
fn apply_test_header(headers: &mut HeaderMap, forced_host_header: &reqwest::header::HeaderValue) {
use reqwest::header::HOST;
_ = headers.remove(HOST);
headers.insert(HOST, forced_host_header.to_owned());
}

#[cfg(feature = "test-gateway")]
fn create_test_header(api_base: &Url, identity: &Did) -> Result<reqwest::header::HeaderValue> {
let mod_identity = identity
.as_str()
.strip_prefix("did:key:")
.ok_or_else(|| anyhow!("Could not format Host header for test-gateway."))?;
let domain = api_base
.domain()
.ok_or_else(|| anyhow!("Host header does not have domain."))?;
let port = api_base.port();

let new_host = if let Some(port) = port {
format!("{}.{}:{}", mod_identity, domain, port)
} else {
format!("{}.{}", mod_identity, domain)
};

Ok(reqwest::header::HeaderValue::from_str(&new_host)?)
}

#[cfg(all(test, feature = "test-gateway"))]
mod tests {
use super::*;
use reqwest::header::HeaderValue;

#[test]
fn it_creates_test_header_from_url() -> Result<()> {
let identity = Did::from("did:key:z6Mkuj9KHUDzGng3rKPouDgnrJJAk9DiBLRL7nWV4ULMs4E7");
let mod_id = "z6Mkuj9KHUDzGng3rKPouDgnrJJAk9DiBLRL7nWV4ULMs4E7";
let expectations = [
("http://localhost", format!("{mod_id}.localhost")),
("http://localhost:1234", format!("{mod_id}.localhost:1234")),
("http://foo.bar", format!("{mod_id}.foo.bar")),
("http://foo.bar:1234", format!("{mod_id}.foo.bar:1234")),
];

for (api_base, expected_host) in expectations {
assert_eq!(
create_test_header(&Url::parse(api_base)?, &identity)?,
HeaderValue::from_str(&expected_host)?
);
}

assert!(create_test_header(&Url::parse("http://127.0.0.1")?, &identity).is_err());
assert!(create_test_header(&Url::parse("http://127.0.0.1:1234")?, &identity).is_err());
Ok(())
}
}
26 changes: 26 additions & 0 deletions rust/noosphere-core/src/context/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ use tokio::sync::OnceCell;
use ucan::crypto::{did::DidParser, KeyMaterial};
use url::Url;

#[cfg(feature = "test-gateway")]
const GATEWAY_OVERRIDE_HOST: &str = "gateway_override_host";

#[cfg(doc)]
use crate::context::has::HasSphereContext;

Expand Down Expand Up @@ -224,6 +227,10 @@ where
.client
.get_or_try_init::<anyhow::Error, _, _>(|| async {
let gateway_url: Url = self.db.require_key(GATEWAY_URL).await?;
#[cfg(feature = "test-gateway")]
let host_header = { self.db.get_key(GATEWAY_OVERRIDE_HOST).await? };
#[cfg(not(feature = "test-gateway"))]
let host_header = { None };

Ok(Arc::new(
Client::identify(
Expand All @@ -233,6 +240,7 @@ where
// TODO: Kill `DidParser` with fire
&mut DidParser::new(SUPPORTED_KEYS),
self.db.clone(),
host_header,
)
.await?,
))
Expand All @@ -247,6 +255,24 @@ where
pub(crate) fn reset_access(&mut self) {
self.access.take();
}

#[cfg(feature = "test-gateway")]
pub async fn configure_gateway_host(&mut self, host: Option<&str>) -> Result<()> {
self.client = OnceCell::new();

match host {
Some(host) => {
self.db
.set_key(GATEWAY_OVERRIDE_HOST, host.to_owned())
.await?;
}
None => {
self.db.unset_key(GATEWAY_OVERRIDE_HOST).await?;
}
}

Ok(())
}
}

#[cfg(test)]
Expand Down

0 comments on commit 0def4d6

Please sign in to comment.