Skip to content

Commit

Permalink
session based handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
MarinPostma committed Oct 24, 2023
1 parent c05297b commit 33b8ee6
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 57 deletions.
4 changes: 4 additions & 0 deletions libsql-replication/proto/replication_log.proto
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ message LogOffset {
message HelloRequest {}

message HelloResponse {
/// id of the replicated log
string log_id = 3;
/// string-encoded Uuid v4 token for the current session, changes on each restart, and must be passed in subsequent requests header.string
/// If the header session token fails to match the current session token, a NO_HELLO error is returned
bytes session_token = 4;
}

message Frame {
Expand Down
16 changes: 15 additions & 1 deletion libsql-replication/src/replicator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub enum Error {
NeedSnapshot,
#[error("Replication meta error: {0}")]
Meta(#[from] super::meta::Error),
#[error("Hanshake required")]
NoHandshake,
}

impl From<tokio::task::JoinError> for Error {
Expand Down Expand Up @@ -182,7 +184,19 @@ impl<C: ReplicatorClient> Replicator<C> {
tracing::debug!("loading snapshot");
// remove any outstanding frames in the buffer that are not part of a
// transaction: they are now part of the snapshot.
self.load_snapshot().await?;
match self.load_snapshot().await {
Ok(()) => (),
Err(Error::NoHandshake) => {

Check failure on line 189 in libsql-replication/src/replicator.rs

View workflow job for this annotation

GitHub Actions / Run Checks

Diff in /home/runner/work/libsql/libsql/libsql-replication/src/replicator.rs
self.has_handshake = false;
self.try_perform_handshake().await?;
},
Err(e) => return Err(e),
}
}
Some(Err(Error::NoHandshake)) => {
tracing::debug!("session expired, new handshake required");
self.has_handshake = false;
self.try_perform_handshake().await?;
}
Some(Err(e)) => return Err(e),
None => return Ok(()),
Expand Down
13 changes: 13 additions & 0 deletions libsql-replication/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,21 @@ pub mod proxy {

pub mod replication {
#![allow(clippy::all)]

use uuid::Uuid;
tonic::include_proto!("wal_log");

pub const NO_HELLO_ERROR_MSG: &str = "NO_HELLO";
pub const NEED_SNAPSHOT_ERROR_MSG: &str = "NEED_SNAPSHOT";

pub const SESSION_TOKEN_KEY: &str = "x-session-token";
pub const NAMESPACE_METADATA_KEY: &str = "x-namespace-bin";

Check failure on line 16 in libsql-replication/src/rpc.rs

View workflow job for this annotation

GitHub Actions / Run Checks

Diff in /home/runner/work/libsql/libsql/libsql-replication/src/rpc.rs

// Verify that the session token is valid
pub fn verify_session_token(token: &[u8]) -> Result<(), Box<dyn std::error::Error + Sync + Send + 'static>> {
let s = std::str::from_utf8(token)?;
s.parse::<Uuid>()?;

Ok(())
}
}
3 changes: 2 additions & 1 deletion libsql-server/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use anyhow::{bail, Context as _, Result};
use axum::http::HeaderValue;
use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY;
use tonic::Status;

use crate::{namespace::NamespaceName, rpc::NAMESPACE_METADATA_KEY};
use crate::namespace::NamespaceName;

static GRPC_AUTH_HEADER: &str = "x-authorization";
static GRPC_PROXY_AUTH_HEADER: &str = "x-proxy-authorization";
Expand Down
2 changes: 1 addition & 1 deletion libsql-server/src/connection/write_proxy.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::path::PathBuf;
use std::sync::Arc;

use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY;
use parking_lot::Mutex as PMutex;
use rusqlite::types::ValueRef;
use sqld_libsql_bindings::wal_hook::{TransparentMethods, TRANSPARENT_METHODS};
Expand All @@ -23,7 +24,6 @@ use crate::replication::FrameNo;
use crate::rpc::proxy::rpc::proxy_client::ProxyClient;
use crate::rpc::proxy::rpc::query_result::RowResult;
use crate::rpc::proxy::rpc::{DisconnectMessage, ExecuteResults};
use crate::rpc::NAMESPACE_METADATA_KEY;
use crate::stats::Stats;
use crate::{Result, DEFAULT_AUTO_CHECKPOINT};

Expand Down
5 changes: 5 additions & 0 deletions libsql-server/src/namespace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,11 @@ impl Namespace<ReplicaDatabase> {
| Error::NeedSnapshot) => {
tracing::warn!("non-fatal replication error, retrying from last commit index: {e}");
},
Error::NoHandshake => {
// not strictly necessary, but in case the handshake error goes uncaught,
// we reset the client state.
replicator.client_mut().reset_token();
}
}
}
});
Expand Down
30 changes: 26 additions & 4 deletions libsql-server/src/replication/replicator_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@ use libsql_replication::frame::Frame;
use libsql_replication::meta::WalIndexMeta;
use libsql_replication::replicator::{map_frame_err, Error, ReplicatorClient};
use libsql_replication::rpc::replication::replication_log_client::ReplicationLogClient;

Check failure on line 7 in libsql-server/src/replication/replicator_client.rs

View workflow job for this annotation

GitHub Actions / Run Checks

Diff in /home/runner/work/libsql/libsql/libsql-server/src/replication/replicator_client.rs
use libsql_replication::rpc::replication::{HelloRequest, LogOffset};
use libsql_replication::rpc::replication::{HelloRequest, LogOffset, NAMESPACE_METADATA_KEY, SESSION_TOKEN_KEY, NO_HELLO_ERROR_MSG, verify_session_token};
use tokio::sync::watch;
use tokio_stream::{Stream, StreamExt};
use tonic::metadata::BinaryMetadataValue;
use tonic::metadata::{BinaryMetadataValue, AsciiMetadataValue};
use tonic::transport::Channel;
use tonic::{Code, Request};
use bytes::Bytes;

use crate::namespace::NamespaceName;
use crate::replication::FrameNo;
use crate::rpc::{NAMESPACE_DOESNT_EXIST, NAMESPACE_METADATA_KEY};
use crate::rpc::NAMESPACE_DOESNT_EXIST;

pub struct Client {
client: ReplicationLogClient<Channel>,
meta: WalIndexMeta,
pub current_frame_no_notifier: watch::Sender<Option<FrameNo>>,
namespace: NamespaceName,
session_token: Option<Bytes>,
}

impl Client {
Expand All @@ -37,6 +39,7 @@ impl Client {
client,
current_frame_no_notifier,
meta,
session_token: None,
})
}

Expand All @@ -47,6 +50,11 @@ impl Client {
BinaryMetadataValue::from_bytes(self.namespace.as_slice()),
);

Check failure on line 52 in libsql-server/src/replication/replicator_client.rs

View workflow job for this annotation

GitHub Actions / Run Checks

Diff in /home/runner/work/libsql/libsql/libsql-server/src/replication/replicator_client.rs
if let Some(token) = self.session_token.clone() {
// SAFETY: we always check the session token
req.metadata_mut().insert(SESSION_TOKEN_KEY, unsafe { AsciiMetadataValue::from_shared_unchecked(token) });
}

req
}

Expand All @@ -56,6 +64,10 @@ impl Client {
None => 0,
}
}

pub(crate) fn reset_token(&mut self) {
self.session_token = None;
}
}

#[derive(Debug, thiserror::Error)]
Expand All @@ -72,6 +84,8 @@ impl ReplicatorClient for Client {
match self.client.hello(req).await {
Ok(resp) => {
let hello = resp.into_inner();
verify_session_token(&hello.session_token).map_err(Error::Client)?;
self.session_token.replace(hello.session_token.clone());
self.meta.merge_hello(hello)?;
self.current_frame_no_notifier
.send_replace(self.meta.current_frame_no());
Expand Down Expand Up @@ -113,7 +127,7 @@ impl ReplicatorClient for Client {
.client
.snapshot(req)
.await
.map_err(|e| Error::Client(e.into()))?
.map_err(map_status)?
.into_inner()
.map(map_frame_err);
Ok(Box::pin(stream))
Expand All @@ -133,3 +147,11 @@ impl ReplicatorClient for Client {
self.meta.current_frame_no()
}
}

fn map_status(status: tonic::Status) -> Error {
if status.code() == Code::FailedPrecondition && status.message() == NO_HELLO_ERROR_MSG {
Error::NoHandshake
} else {
Error::Client(status.into())
}
}
2 changes: 1 addition & 1 deletion libsql-server/src/rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;

use anyhow::Context;
use hyper_rustls::TlsAcceptor;
use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY;
use rustls::server::AllowAnyAuthenticatedClient;
use rustls::RootCertStore;
use tonic::Status;
Expand All @@ -22,7 +23,6 @@ pub mod replication_log_proxy;

/// A tonic error code to signify that a namespace doesn't exist.
pub const NAMESPACE_DOESNT_EXIST: &str = "NAMESPACE_DOESNT_EXIST";
pub(crate) const NAMESPACE_METADATA_KEY: &str = "x-namespace-bin";

pub async fn run_rpc_server<A: crate::net::Accept>(
proxy_service: ProxyService,
Expand Down
61 changes: 22 additions & 39 deletions libsql-server/src/rpc/replication_log.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
use std::collections::HashSet;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::sync::Arc;

use bytes::Bytes;
use futures::stream::BoxStream;
pub use libsql_replication::rpc::replication as rpc;
use libsql_replication::rpc::replication::replication_log_server::ReplicationLog;
use libsql_replication::rpc::replication::{
Frame, Frames, HelloRequest, HelloResponse, LogOffset, NEED_SNAPSHOT_ERROR_MSG,
NO_HELLO_ERROR_MSG,
NO_HELLO_ERROR_MSG, SESSION_TOKEN_KEY,
};
use tokio_stream::StreamExt;
use tonic::Status;
use uuid::Uuid;

use crate::auth::Auth;
use crate::namespace::{NamespaceName, NamespaceStore, PrimaryNamespaceMaker};
use crate::namespace::{NamespaceStore, PrimaryNamespaceMaker};
use crate::replication::primary::frame_stream::FrameStream;
use crate::replication::LogReadError;
use crate::utils::services::idle_shutdown::IdleShutdownKicker;
Expand All @@ -23,10 +23,10 @@ use super::NAMESPACE_DOESNT_EXIST;

pub struct ReplicationLogService {
namespaces: NamespaceStore<PrimaryNamespaceMaker>,
replicas_with_hello: RwLock<HashSet<(SocketAddr, NamespaceName)>>,
idle_shutdown_layer: Option<IdleShutdownKicker>,
auth: Option<Arc<Auth>>,
disable_namespaces: bool,
session_token: Bytes,
}

pub const MAX_FRAMES_PER_BATCH: usize = 1024;
Expand All @@ -38,9 +38,10 @@ impl ReplicationLogService {
auth: Option<Arc<Auth>>,
disable_namespaces: bool,
) -> Self {
let session_token = Uuid::new_v4().to_string().into();
Self {
namespaces,
replicas_with_hello: Default::default(),
session_token,
idle_shutdown_layer,
auth,
disable_namespaces,
Expand All @@ -54,6 +55,16 @@ impl ReplicationLogService {

Ok(())
}

Check failure on line 57 in libsql-server/src/rpc/replication_log.rs

View workflow job for this annotation

GitHub Actions / Run Checks

Diff in /home/runner/work/libsql/libsql/libsql-server/src/rpc/replication_log.rs

fn verify_session_token<R>(&self, req: &tonic::Request<R>) -> Result<(), Status> {
let no_hello = || { Err(Status::failed_precondition(NO_HELLO_ERROR_MSG)) };
let Some(token) = req.metadata().get(SESSION_TOKEN_KEY) else { return no_hello() };
if token.as_bytes() != self.session_token {
return no_hello();
}

Ok(())
}
}

fn map_frame_stream_output(
Expand Down Expand Up @@ -122,18 +133,11 @@ impl ReplicationLog for ReplicationLogService {
req: tonic::Request<LogOffset>,
) -> Result<tonic::Response<Self::LogEntriesStream>, Status> {
self.authenticate(&req)?;
self.verify_session_token(&req)?;

let namespace = super::extract_namespace(self.disable_namespaces, &req)?;

let replica_addr = req
.remote_addr()
.ok_or(Status::internal("No remote RPC address"))?;
let req = req.into_inner();
{
let guard = self.replicas_with_hello.read().unwrap();
if !guard.contains(&(replica_addr, namespace.clone())) {
return Err(Status::failed_precondition(NO_HELLO_ERROR_MSG));
}
}

let logger = self
.namespaces
Expand Down Expand Up @@ -162,19 +166,10 @@ impl ReplicationLog for ReplicationLogService {
req: tonic::Request<LogOffset>,
) -> Result<tonic::Response<Frames>, Status> {
self.authenticate(&req)?;
self.verify_session_token(&req)?;
let namespace = super::extract_namespace(self.disable_namespaces, &req)?;

let replica_addr = req
.remote_addr()
.ok_or(Status::internal("No remote RPC address"))?;
let req = req.into_inner();
{
let guard = self.replicas_with_hello.read().unwrap();
if !guard.contains(&(replica_addr, namespace.clone())) {
return Err(Status::failed_precondition(NO_HELLO_ERROR_MSG));
}
}

let logger = self
.namespaces
.with(namespace, |ns| ns.db.logger.clone())
Expand Down Expand Up @@ -206,19 +201,6 @@ impl ReplicationLog for ReplicationLogService {
self.authenticate(&req)?;
let namespace = super::extract_namespace(self.disable_namespaces, &req)?;


use tonic::transport::server::TcpConnectInfo;

req.extensions().get::<TcpConnectInfo>().unwrap();
let replica_addr = req
.remote_addr()
.ok_or(Status::internal("No remote RPC address"))?;

{
let mut guard = self.replicas_with_hello.write().unwrap();
guard.insert((replica_addr, namespace.clone()));
}

let logger = self
.namespaces
.with(namespace, |ns| ns.db.logger.clone())
Expand All @@ -233,6 +215,7 @@ impl ReplicationLog for ReplicationLogService {

let response = HelloResponse {
log_id: logger.log_id().to_string(),
session_token: self.session_token.clone(),
};

Ok(tonic::Response::new(response))
Expand Down
6 changes: 3 additions & 3 deletions libsql-server/tests/embedded_replica/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ fn embedded_replica() {
//
// This does change the serialization mode for sqld but because the mode
// that we use in libsql is safer than the sqld one it is still safe.
// let db = Database::open_in_memory().unwrap();
// db.connect().unwrap();
let db = Database::open_in_memory().unwrap();
db.connect().unwrap();

make_primary(&mut sim, tmp_host_path.clone());

Expand Down Expand Up @@ -102,7 +102,7 @@ fn embedded_replica() {
.await?;

let n = db.sync().await?;
assert_eq!(n, Some(2));
assert_eq!(n, Some(1));

let err = conn
.execute("INSERT INTO user(id) VALUES (1), (1)", ())
Expand Down
2 changes: 1 addition & 1 deletion libsql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
//! let mut db = Database::open_with_local_sync("/tmp/test.db").await.unwrap();
//!
//! let frames = Frames::Vec(vec![]);
//! db.sync_frames(frames).unwrap();
//! db.sync_frames(frames).await.unwrap();
//! let conn = db.connect().unwrap();
//! conn.execute("SELECT * FROM users", ()).await.unwrap();
//! # }
Expand Down
2 changes: 1 addition & 1 deletion libsql/src/local/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl Connection {
};
match err {
ffi::SQLITE_OK => {}
e => {
_ => {
return Err(Error::ConnectionFailed(db_path));
}
}
Expand Down
Loading

0 comments on commit 33b8ee6

Please sign in to comment.