Skip to content

Commit

Permalink
feat(sdk)!: allow setting CA cert (#1924)
Browse files Browse the repository at this point in the history
  • Loading branch information
lklimek authored Dec 17, 2024
1 parent 82a6217 commit 8185d21
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 27 additions & 5 deletions packages/rs-dapi-client/src/dapi_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use backon::{ConstantBuilder, Retryable};
use dapi_grpc::mock::Mockable;
use dapi_grpc::tonic::async_trait;
use dapi_grpc::tonic::transport::Certificate;
use std::fmt::{Debug, Display};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
Expand Down Expand Up @@ -76,6 +77,8 @@ pub struct DapiClient {
address_list: AddressList,
settings: RequestSettings,
pool: ConnectionPool,
/// Certificate Authority certificate to use for verifying the server's certificate.
pub ca_certificate: Option<Certificate>,
#[cfg(feature = "dump")]
pub(crate) dump_dir: Option<std::path::PathBuf>,
}
Expand All @@ -92,9 +95,24 @@ impl DapiClient {
pool: ConnectionPool::new(address_count),
#[cfg(feature = "dump")]
dump_dir: None,
ca_certificate: None,
}
}

/// Set CA certificate to use when verifying the server's certificate.
///
/// # Arguments
///
/// * `pem_ca_cert` - CA certificate in PEM format.
///
/// # Returns
/// [DapiClient] with CA certificate set.
pub fn with_ca_certificate(mut self, ca_cert: Certificate) -> Self {
self.ca_certificate = Some(ca_cert);

self
}

/// Return the [DapiClient] address list.
pub fn address_list(&self) -> &AddressList {
&self.address_list
Expand Down Expand Up @@ -182,7 +200,8 @@ impl DapiRequestExecutor for DapiClient {
.settings
.override_by(R::SETTINGS_OVERRIDES)
.override_by(settings)
.finalize();
.finalize()
.with_ca_certificate(self.ca_certificate.clone());

// Setup retry policy:
let retry_settings = ConstantBuilder::default()
Expand All @@ -198,6 +217,9 @@ impl DapiRequestExecutor for DapiClient {
let retries_counter_arc = Arc::new(AtomicUsize::new(0));
let retries_counter_arc_ref = &retries_counter_arc;

// We need reference so that the closure is FnMut
let applied_settings_ref = &applied_settings;

// Setup DAPI request execution routine future. It's a closure that will be called
// more once to build new future on each retry.
let routine = move || {
Expand All @@ -212,7 +234,7 @@ impl DapiRequestExecutor for DapiClient {
let _span = tracing::trace_span!(
"execute request",
address = ?address_result,
settings = ?applied_settings,
settings = ?applied_settings_ref,
method = request.method_name(),
)
.entered();
Expand Down Expand Up @@ -242,7 +264,7 @@ impl DapiRequestExecutor for DapiClient {

let mut transport_client = R::Client::with_uri_and_settings(
address.uri().clone(),
&applied_settings,
applied_settings_ref,
&pool,
)
.map_err(|error| ExecutionError {
Expand All @@ -252,7 +274,7 @@ impl DapiRequestExecutor for DapiClient {
})?;

let result = transport_request
.execute_transport(&mut transport_client, &applied_settings)
.execute_transport(&mut transport_client, applied_settings_ref)
.await
.map_err(DapiClientError::Transport);

Expand Down Expand Up @@ -281,7 +303,7 @@ impl DapiRequestExecutor for DapiClient {
update_address_ban_status::<R::Response, DapiClientError>(
&self.address_list,
&execution_result,
&applied_settings,
applied_settings_ref,
);

execution_result
Expand Down
15 changes: 14 additions & 1 deletion packages/rs-dapi-client/src/request_settings.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! DAPI client request settings processing.
use dapi_grpc::tonic::transport::Certificate;
use std::time::Duration;

/// Default low-level client timeout
Expand Down Expand Up @@ -64,12 +65,13 @@ impl RequestSettings {
ban_failed_address: self
.ban_failed_address
.unwrap_or(DEFAULT_BAN_FAILED_ADDRESS),
ca_certificate: None,
}
}
}

/// DAPI settings ready to use.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub struct AppliedRequestSettings {
/// Timeout for establishing a connection.
pub connect_timeout: Option<Duration>,
Expand All @@ -79,4 +81,15 @@ pub struct AppliedRequestSettings {
pub retries: usize,
/// Ban DAPI address if node not responded or responded with error.
pub ban_failed_address: bool,
/// Certificate Authority certificate to use for verifying the server's certificate.
pub ca_certificate: Option<Certificate>,
}
impl AppliedRequestSettings {
/// Use provided CA certificate for verifying the server's certificate.
///
/// If set to None, the system's default CA certificates will be used.
pub fn with_ca_certificate(mut self, ca_cert: Option<Certificate>) -> Self {
self.ca_certificate = ca_cert;
self
}
}
34 changes: 24 additions & 10 deletions packages/rs-dapi-client/src/transport/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{request_settings::AppliedRequestSettings, RequestSettings};
use dapi_grpc::core::v0::core_client::CoreClient;
use dapi_grpc::core::v0::{self as core_proto};
use dapi_grpc::platform::v0::{self as platform_proto, platform_client::PlatformClient};
use dapi_grpc::tonic::transport::{ClientTlsConfig, Uri};
use dapi_grpc::tonic::transport::{Certificate, ClientTlsConfig, Uri};
use dapi_grpc::tonic::Streaming;
use dapi_grpc::tonic::{transport::Channel, IntoRequest};
use futures::{future::BoxFuture, FutureExt, TryFutureExt};
Expand All @@ -22,19 +22,29 @@ fn create_channel(
uri: Uri,
settings: Option<&AppliedRequestSettings>,
) -> Result<Channel, dapi_grpc::tonic::transport::Error> {
let mut builder = Channel::builder(uri).tls_config(
ClientTlsConfig::new()
.with_native_roots()
.with_webpki_roots()
.assume_http2(true),
)?;
let host = uri.host().expect("Failed to get host from URI").to_string();

let mut builder = Channel::builder(uri);
let mut tls_config = ClientTlsConfig::new()
.with_native_roots()
.with_webpki_roots()
.assume_http2(true);

if let Some(settings) = settings {
if let Some(timeout) = settings.connect_timeout {
builder = builder.connect_timeout(timeout);
}

if let Some(pem) = settings.ca_certificate.as_ref() {
let cert = Certificate::from_pem(pem);
tls_config = tls_config.ca_certificate(cert).domain_name(host);
};
}

builder = builder
.tls_config(tls_config)
.expect("Failed to set TLS config");

Ok(builder.connect_lazy())
}

Expand Down Expand Up @@ -256,8 +266,10 @@ impl_transport_request_grpc!(
platform_proto::WaitForStateTransitionResultResponse,
PlatformGrpcClient,
RequestSettings {
timeout: Some(Duration::from_secs(120)),
..RequestSettings::default()
timeout: Some(Duration::from_secs(80)),
retries: Some(0),
ban_failed_address: None,
connect_timeout: None,
},
wait_for_state_transition_result
);
Expand Down Expand Up @@ -487,7 +499,9 @@ impl_transport_request_grpc!(
CoreGrpcClient,
RequestSettings {
timeout: Some(STREAMING_TIMEOUT),
..RequestSettings::default()
ban_failed_address: None,
connect_timeout: None,
retries: None,
},
subscribe_to_transactions_with_proofs
);
Expand Down
1 change: 1 addition & 0 deletions packages/rs-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ drive = { path = "../rs-drive", default-features = false, features = [
drive-proof-verifier = { path = "../rs-drive-proof-verifier" }
dapi-grpc-macros = { path = "../rs-dapi-grpc-macros" }
http = { version = "1.1" }
rustls-pemfile = { version = "2.0.0" }
thiserror = "1.0.64"
tokio = { version = "1.40", features = ["macros", "rt-multi-thread"] }
tokio-util = { version = "0.7.12" }
Expand Down
47 changes: 46 additions & 1 deletion packages/rs-sdk/src/sdk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::platform::{Fetch, Identifier};
use arc_swap::{ArcSwapAny, ArcSwapOption};
use dapi_grpc::mock::Mockable;
use dapi_grpc::platform::v0::{Proof, ResponseMetadata};
use dapi_grpc::tonic::transport::Certificate;
use dpp::bincode;
use dpp::bincode::error::DecodeError;
use dpp::dashcore::Network;
Expand Down Expand Up @@ -750,6 +751,9 @@ pub struct SdkBuilder {

/// Cancellation token; once cancelled, all pending requests should be aborted.
pub(crate) cancel_token: CancellationToken,

/// CA certificate to use for TLS connections.
ca_certificate: Option<Certificate>,
}

impl Default for SdkBuilder {
Expand Down Expand Up @@ -781,6 +785,8 @@ impl Default for SdkBuilder {

version: PlatformVersion::latest(),

ca_certificate: None,

#[cfg(feature = "mocks")]
dump_dir: None,
}
Expand Down Expand Up @@ -830,6 +836,41 @@ impl SdkBuilder {
self
}

/// Configure CA certificate to use when verifying TLS connections.
///
/// Used mainly for testing purposes and local networks.
///
/// If not set, uses standard system CA certificates.
pub fn with_ca_certificate(mut self, pem_certificate: Certificate) -> Self {
self.ca_certificate = Some(pem_certificate);
self
}

/// Load CA certificate from file.
///
/// This is a convenience method that reads the certificate from a file and sets it using
/// [SdkBuilder::with_ca_certificate()].
pub fn with_ca_certificate_file(
self,
certificate_file_path: impl AsRef<std::path::Path>,
) -> std::io::Result<Self> {
let pem = std::fs::read(certificate_file_path)?;

// parse the certificate and check if it's valid
let mut verified_pem = std::io::BufReader::new(pem.as_slice());
rustls_pemfile::certs(&mut verified_pem)
.next()
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No valid certificates found in the file",
)
})??;

let cert = Certificate::from_pem(pem);
Ok(self.with_ca_certificate(cert))
}

/// Configure request settings.
///
/// Tune request settings used to connect to the Dash Platform.
Expand Down Expand Up @@ -962,7 +1003,11 @@ impl SdkBuilder {
let sdk= match self.addresses {
// non-mock mode
Some(addresses) => {
let dapi = DapiClient::new(addresses,dapi_client_settings);
let mut dapi = DapiClient::new(addresses, dapi_client_settings);
if let Some(pem) = self.ca_certificate {
dapi = dapi.with_ca_certificate(pem);
}

#[cfg(feature = "mocks")]
let dapi = dapi.dump_dir(self.dump_dir.clone());

Expand Down
1 change: 1 addition & 0 deletions packages/rs-sdk/tests/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
DASH_SDK_PLATFORM_HOST="127.0.0.1"
DASH_SDK_PLATFORM_PORT=2443
DASH_SDK_PLATFORM_SSL=false
# DASH_SDK_PLATFORM_CA_CERT_PATH=/some/path/to/ca.pem

# ProTxHash of masternode that has at least 1 vote casted for DPNS name `testname`
DASH_SDK_MASTERNODE_OWNER_PRO_REG_TX_HASH="6ac88f64622d9bc0cb79ad0f69657aa9488b213157d20ae0ca371fa5f04fb222"
Expand Down
10 changes: 9 additions & 1 deletion packages/rs-sdk/tests/fetch/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ pub struct Config {
#[serde(default)]
pub platform_ssl: bool,

/// When platform_ssl is true, use the PEM-encoded CA certificate from provided absolute path to verify the server certificate.
#[serde(default)]
pub platform_ca_cert_path: Option<PathBuf>,

/// Directory where all generated test vectors will be saved.
///
/// See [SdkBuilder::with_dump_dir()](crate::SdkBuilder::with_dump_dir()) for more details.
Expand Down Expand Up @@ -193,7 +197,11 @@ impl Config {
&self.core_user,
&self.core_password,
);

if let Some(cert_file) = &self.platform_ca_cert_path {
builder = builder
.with_ca_certificate_file(cert_file)
.expect("load CA cert");
}
#[cfg(feature = "generate-test-vectors")]
let builder = {
// When we use namespaces, clean up the namespaced dump dir before starting
Expand Down
2 changes: 2 additions & 0 deletions packages/rs-sdk/tests/fetch/data_contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use drive_proof_verifier::types::DataContractHistory;
/// Given some dummy data contract ID, when I fetch data contract, I get None because it doesn't exist.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_data_contract_read_not_found() {
super::common::setup_logs();

pub const DATA_CONTRACT_ID_BYTES: [u8; 32] = [1; 32];
let id = Identifier::from_bytes(&DATA_CONTRACT_ID_BYTES).expect("parse identity id");

Expand Down

0 comments on commit 8185d21

Please sign in to comment.