Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse digest auth challenge #131

Merged
merged 4 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 25 additions & 24 deletions onvif/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,37 @@ license = "MIT"
tls = ["reqwest/native-tls"]

[dependencies]
async-recursion = "0.3.1"
async-trait = "0.1.41"
base64 = "0.13.0"
bigdecimal = "0.3.0"
chrono = "0.4.19"
digest_auth = "0.3.0"
futures = "0.3.30"
futures-core = "0.3.8"
futures-util = "0.3.30"
num-bigint = "0.4.2"
reqwest = { version = "0.12.3", default-features = false }
async-recursion = "0.3"
async-trait = "0.1"
base64 = "0.13"
bigdecimal = "0.3"
chrono = "0.4"
digest_auth = "0.3"
futures = "0.3"
futures-core = "0.3"
futures-util = "0.3"
num-bigint = "0.4"
nonzero_ext = "0.3"
reqwest = { version = "0.12", default-features = false }
schema = { version = "0.1.0", path = "../schema", default-features = false, features = ["analytics", "devicemgmt", "event", "media", "ptz"] }
sha1 = "0.6.0"
sha1 = "0.6"
thiserror = "1.0"
tokio = { version = "1", default-features = false, features = ["net", "sync", "time"] }
tokio-stream = "0.1"
tracing = "0.1.26"
url = "2.2.0"
uuid = { version = "0.8.1", features = ["v4"] }
xml-rs = "=0.8.3"
xmltree = "0.10.2"
tracing = "0.1"
url = "2"
uuid = { version = "1", features = ["v4"] }
xml-rs = "0.8"
xmltree = "0.10"
xsd-macro-utils = { git = "https://github.com/lumeohq/xsd-parser-rs", rev = "7f3d433" }
xsd-types = { git = "https://github.com/lumeohq/xsd-parser-rs", rev = "7f3d433" }
yaserde = "0.7.1"
yaserde_derive = "0.7.1"
yaserde = "0.7"
yaserde_derive = "0.7"

[dev-dependencies]
dotenv = "0.15.0"
futures-util = "0.3.8"
structopt = "0.3.21"
tokio = { version = "1.0.1", features = ["full"] }
tracing-subscriber = "0.2.20"
dotenv = "0.15"
futures-util = "0.3"
structopt = "0.3"
tokio = { version = "1", features = ["full"] }
tracing-subscriber = "0.2"
b_2 = { path = "../wsdl_rs/b_2" }
59 changes: 45 additions & 14 deletions onvif/src/soap/auth/digest.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::soap::client::Credentials;
use nonzero_ext::nonzero;
use reqwest::{RequestBuilder, Response};
use std::fmt::{Debug, Formatter};
use std::num::NonZeroU8;
use thiserror::Error;
use url::Url;

Expand All @@ -22,8 +24,10 @@ pub struct Digest {

enum State {
Default,
Got401(reqwest::Response),
Got401Twice,
Got401 {
response: Response,
count: NonZeroU8,
},
}

impl Digest {
Expand All @@ -37,29 +41,55 @@ impl Digest {
}

impl Digest {
/// Call this when the authentication was successful.
pub fn set_success(&mut self) {
if let State::Got401 { count, .. } = &mut self.state {
// We always store at least one request, so it's never zero.
*count = nonzero!(1_u8);
}
}

/// Call this when received 401 Unauthorized.
pub fn set_401(&mut self, response: Response) {
match self.state {
State::Default => self.state = State::Got401(response),
State::Got401(_) => self.state = State::Got401Twice,
State::Got401Twice => {}
self.state = match self.state {
State::Default => State::Got401 {
response,
count: nonzero!(1_u8),
},
State::Got401 { count, .. } => State::Got401 {
response,
count: count.saturating_add(1),
},
}
}

pub fn is_failed(&self) -> bool {
matches!(self.state, State::Got401Twice)
match &self.state {
State::Default => false,
// Possible scenarios:
// - We've got 401 with a challenge for the first time, we calculate the answer, then
// we get 200 OK. So, a single 401 is never a failure.
// - After successful auth the count is 1 because we always store at least one request,
// and the caller decided to reuse the same challenge for multiple requests. But at
// some point, we'll get a 401 with a new challenge and `stale=true`.
// So, we'll get a second 401, and this is also not a failure because after
// calculating the answer to the challenge, we'll get a 200 OK, and will reset the
// counter in `set_success()`.
// - Three 401's in a row is certainly a failure.
FSMaxB marked this conversation as resolved.
Show resolved Hide resolved
State::Got401 { count, .. } => count.get() >= 3,
}
}

pub fn add_headers(&self, mut request: RequestBuilder) -> Result<RequestBuilder, Error> {
match &self.state {
State::Default => Ok(request),
State::Got401(response) => {
State::Got401 { response, .. } => {
let creds = self.creds.as_ref().ok_or(Error::NoCredentials)?;

request = request.header("Authorization", digest_auth(response, creds, &self.uri)?);

Ok(request)
}
State::Got401Twice => Err(Error::InvalidState),
}
}
}
Expand Down Expand Up @@ -94,10 +124,11 @@ impl Debug for Digest {

impl Debug for State {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
State::Default => "FirstRequest",
State::Got401(_) => "Got401",
State::Got401Twice => "Got401Twice",
})
match self {
State::Default => write!(f, "FirstRequest")?,
State::Got401 { count, .. } => write!(f, "Got401({count})")?,
};

Ok(())
}
}
63 changes: 30 additions & 33 deletions onvif/src/soap/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,22 @@ use crate::soap::{
};
use async_recursion::async_recursion;
use async_trait::async_trait;
use futures_util::lock::Mutex;
use schema::transport::{Error, Transport};
use std::ops::DerefMut;
use std::{
fmt::{Debug, Formatter},
sync::Arc,
time::Duration,
};
use tracing::{debug, instrument, trace};
use url::Url;

macro_rules! event {
($lvl:expr, $self:ident, $($arg:tt)+) => {
tracing::event!($lvl, "{}: {}", $self.config.uri, format_args!($($arg)+))
};
}

macro_rules! debug {
($($arg:tt)+) => {
event!(tracing::Level::DEBUG, $($arg)+)
}
}

#[derive(Clone)]
pub struct Client {
client: reqwest::Client,
config: Config,
digest_auth_state: Arc<Mutex<Digest>>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -95,9 +87,12 @@ impl ClientBuilder {
.unwrap()
};

let digest = Digest::new(&self.config.uri, &self.config.credentials);

Client {
client,
config: self.config,
digest_auth_state: Arc::new(Mutex::new(digest)),
}
}

Expand Down Expand Up @@ -155,20 +150,21 @@ impl Debug for Credentials {
pub type ResponsePatcher = Arc<dyn Fn(&str) -> Result<String, String> + Send + Sync>;

#[derive(Debug)]
enum RequestAuthType {
Digest(Digest),
enum RequestAuthType<'a> {
Digest(&'a mut Digest),
UsernameToken,
}

#[async_trait]
impl Transport for Client {
#[instrument(skip_all, fields(uri = self.config.uri.as_str()))]
async fn request(&self, message: &str) -> Result<String, Error> {
match self.config.auth_type {
AuthType::Any => {
match self.request_with_digest(message).await {
Ok(success) => Ok(success),
Err(Error::Authorization(e)) => {
debug!(self, "Failed to authorize with Digest auth: {}. Trying UsernameToken auth ...", e);
debug!("Failed to authorize with Digest auth: {e}. Trying UsernameToken auth ...");
self.request_with_username_token(message).await
}
Err(e) => Err(e),
Expand All @@ -182,8 +178,8 @@ impl Transport for Client {

impl Client {
async fn request_with_digest(&self, message: &str) -> Result<String, Error> {
let mut auth_type =
RequestAuthType::Digest(Digest::new(&self.config.uri, &self.config.credentials));
let mut guard = self.digest_auth_state.lock().await;
let mut auth_type = RequestAuthType::Digest(guard.deref_mut());

self.request_recursive(message, &self.config.uri, &mut auth_type, 0)
.await
Expand All @@ -209,13 +205,10 @@ impl Client {
_ => None,
};

debug!(
self,
"About to make request. auth_type={:?}, redirections={}", auth_type, redirections
);
debug!(?auth_type, %redirections, "About to make request.");

let soap_msg = soap::soap(message, &username_token)
.map_err(|e| Error::Protocol(format!("{:?}", e)))?;
let soap_msg =
soap::soap(message, &username_token).map_err(|e| Error::Protocol(format!("{e:?}")))?;

let mut request = self
.client
Expand All @@ -227,10 +220,10 @@ impl Client {
.add_headers(request)
.map_err(|e| Error::Authorization(e.to_string()))?;

debug!(self, "Digest headers added");
debug!("Digest headers added");
}

debug!(self, "Request body: {}", soap_msg);
trace!("Request body: {soap_msg}");

let response = request.body(soap_msg).send().await.map_err(|e| match e {
e if e.is_connect() => Error::Connection(e.to_string()),
Expand All @@ -242,24 +235,28 @@ impl Client {

let status = response.status();

debug!(self, "Response status: {}", status);
debug!("Response status: {status}");

if status.is_success() {
if let RequestAuthType::Digest(digest) = auth_type {
digest.set_success();
}

response
.text()
.await
.map_err(|e| Error::Protocol(e.to_string()))
.and_then(|text| {
debug!(self, "Response body: {}", text);
trace!("Response body: {text}");
let response =
soap::unsoap(&text).map_err(|e| Error::Protocol(format!("{:?}", e)))?;
soap::unsoap(&text).map_err(|e| Error::Protocol(format!("{e:?}")))?;
if let Some(response_patcher) = &self.config.response_patcher {
match response_patcher(&response) {
Ok(patched) => {
debug!(self, "Response (SOAP unwrapped, patched): {}", patched);
trace!("Response (SOAP unwrapped, patched): {patched}");
Ok(patched)
}
Err(e) => Err(Error::Protocol(format!("Patching failed: {}", e))),
Err(e) => Err(Error::Protocol(format!("Patching failed: {e}"))),
}
} else {
Ok(response)
Expand All @@ -272,7 +269,7 @@ impl Client {
}
_ => {
if let Ok(text) = response.text().await {
debug!(self, "Got Unauthorized with body: {}", text);
trace!("Got Unauthorized with body: {text}");
}

return Err(Error::Authorization("Unauthorized".to_string()));
Expand All @@ -291,13 +288,13 @@ impl Client {

let new_url = Client::get_redirect_location(&response)?;

debug!(self, "Redirecting to {} ...", new_url);
debug!("Redirecting to {new_url} ...");

self.request_recursive(message, &new_url, auth_type, redirections + 1)
.await
} else {
if let Ok(text) = response.text().await {
debug!(self, "Got HTTP error with body: {}", text);
trace!("Got HTTP error with body: {text}");
if let Err(soap::Error::Fault(f)) = soap::unsoap(&text) {
if f.is_unauthorized() {
return Err(Error::Authorization("Unauthorized".to_string()));
Expand Down
9 changes: 3 additions & 6 deletions onvif/src/soap/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ fn test_soap() {
</my:Book>
"#;

let expected = r#"
<?xml version="1.0" encoding="UTF-8"?>
let expected = r#"<?xml version="1.0" encoding="UTF-8"?>
<s:Envelope xmlns:s="http://www.w3.org/2003/05/soap-envelope"
xmlns:my="http://www.example.my/schema">
<s:Body>
Expand Down Expand Up @@ -44,8 +43,7 @@ fn test_unsoap() {
pub pages: i32,
}

let input = r#"
<?xml version="1.0" encoding="utf-8"?>
let input = r#"<?xml version="1.0" encoding="utf-8"?>
<s:Envelope xmlns:s="http://www.w3.org/2003/05/soap-envelope"
xmlns:my="http://www.example.my/schema">
<s:Body>
Expand All @@ -69,8 +67,7 @@ fn test_unsoap() {

#[test]
fn test_get_fault() {
let response = r#"
<?xml version="1.0" ?>
let response = r#"<?xml version="1.0" ?>
<soapenv:Fault
xmlns:soapenv="http://www.w3.org/2003/05/soap-envelope"
xmlns:ter="http://www.onvif.org/ver10/error"
Expand Down
Loading