Skip to content

Commit

Permalink
Merge #325
Browse files Browse the repository at this point in the history
325: remove posgres wire over ws, and refactor r=MarinPostma a=MarinPostma

remove posgres wire porotocol over websocket and refactor to simlify architecture


Co-authored-by: ad hoc <[email protected]>
  • Loading branch information
bors[bot] and MarinPostma authored Apr 8, 2023
2 parents b6732fb + c93f564 commit bab0c66
Show file tree
Hide file tree
Showing 19 changed files with 143 additions and 1,034 deletions.
24 changes: 24 additions & 0 deletions libsql-server/sqld/src/database/factory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use std::sync::Arc;

use futures::Future;

use super::Database;
use crate::error::Error;

#[async_trait::async_trait]
pub trait DbFactory: Send + Sync {
async fn create(&self) -> Result<Arc<dyn Database>, Error>;
}

#[async_trait::async_trait]
impl<F, DB, Fut> DbFactory for F
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<DB, Error>> + Send,
DB: Database + Sync + Send + 'static,
{
async fn create(&self) -> Result<Arc<dyn Database>, Error> {
let db = (self)().await?;
Ok(Arc::new(db))
}
}
2 changes: 1 addition & 1 deletion libsql-server/sqld/src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::query_analysis::{State, Statement};
use crate::Result;

pub mod dump_loader;
pub mod factory;
pub mod libsql;
pub mod service;
pub mod write_proxy;

const TXN_TIMEOUT_SECS: u64 = 5;
Expand Down
83 changes: 0 additions & 83 deletions libsql-server/sqld/src/database/service.rs

This file was deleted.

2 changes: 1 addition & 1 deletion libsql-server/sqld/src/database/write_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::stats::Stats;
use crate::Result;

use super::Program;
use super::{libsql::LibSqlDb, service::DbFactory, Database};
use super::{factory::DbFactory, libsql::LibSqlDb, Database};

#[derive(Clone)]
pub struct WriteProxyDbFactory {
Expand Down
2 changes: 1 addition & 1 deletion libsql-server/sqld/src/hrana/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::auth::Auth;
use crate::database::service::DbFactory;
use crate::database::factory::DbFactory;
use crate::utils::services::idle_shutdown::IdleKicker;
use anyhow::{Context as _, Result};
use enclose::enclose;
Expand Down
2 changes: 1 addition & 1 deletion libsql-server/sqld/src/http/hrana_over_http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::future::Future;
use std::sync::Arc;

use crate::database::service::DbFactory;
use crate::database::factory::DbFactory;
use crate::database::Database;
use crate::hrana;

Expand Down
63 changes: 10 additions & 53 deletions libsql-server/sqld/src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ mod hrana_over_http;
mod stats;
mod types;

use std::future::poll_fn;
use std::net::SocketAddr;
use std::sync::Arc;

Expand All @@ -16,15 +15,13 @@ use serde::Serialize;
use serde_json::{json, Number};
use tokio::sync::{mpsc, oneshot};
use tonic::codegen::http;
use tower::balance::pool;
use tower::load::Load;
use tower::{BoxError, MakeService, Service, ServiceBuilder};
use tower::ServiceBuilder;
use tower_http::trace::DefaultOnResponse;
use tower_http::{compression::CompressionLayer, cors};
use tracing::{Level, Span};

use crate::auth::Auth;
use crate::database::service::DbFactory;
use crate::database::factory::DbFactory;
use crate::error::Error;
use crate::hrana;
use crate::http::types::HttpQuery;
Expand Down Expand Up @@ -152,12 +149,6 @@ fn parse_queries(queries: Vec<QueryObject>) -> anyhow::Result<Vec<Query>> {
Ok(out)
}

/// Internal Message used to communicate between the HTTP service
struct Message {
batch: Vec<Query>,
resp: oneshot::Sender<Result<Vec<Option<QueryResult>>, BoxError>>,
}

fn parse_payload(data: &[u8]) -> Result<HttpQuery, Response<Body>> {
match serde_json::from_slice(data) {
Ok(data) => Ok(data),
Expand All @@ -167,34 +158,29 @@ fn parse_payload(data: &[u8]) -> Result<HttpQuery, Response<Body>> {

async fn handle_query(
mut req: Request<Body>,
sender: mpsc::Sender<Message>,
db_factory: Arc<dyn DbFactory>,
) -> anyhow::Result<Response<Body>> {
let bytes = to_bytes(req.body_mut()).await?;
let req = match parse_payload(&bytes) {
Ok(req) => req,
Err(resp) => return Ok(resp),
};

let (s, resp) = oneshot::channel();

let batch = match parse_queries(req.statements) {
Ok(queries) => queries,
Err(e) => return Ok(error(&e.to_string(), StatusCode::BAD_REQUEST)),
};

let msg = Message { batch, resp: s };
let _ = sender.send(msg).await;

let result = resp.await;
let db = db_factory.create().await?;

match result {
Ok(Ok(rows)) => {
match db.execute_batch_or_rollback(batch).await {
Ok((rows, _)) => {
let json = query_response_to_json(rows)?;
Ok(Response::builder()
.header("Content-Type", "application/json")
.body(Body::from(json))?)
}
Err(_) | Ok(Err(_)) => Ok(error("internal error", StatusCode::INTERNAL_SERVER_ERROR)),
Err(_) => Ok(error("internal error", StatusCode::INTERNAL_SERVER_ERROR)),
}
}

Expand Down Expand Up @@ -234,7 +220,6 @@ async fn handle_upgrade(
async fn handle_request(
auth: Arc<Auth>,
req: Request<Body>,
sender: mpsc::Sender<Message>,
upgrade_tx: mpsc::Sender<hrana::Upgrade>,
db_factory: Arc<dyn DbFactory>,
enable_console: bool,
Expand All @@ -253,7 +238,7 @@ async fn handle_request(
}

match (req.method(), req.uri().path()) {
(&Method::POST, "/") => handle_query(req, sender).await,
(&Method::POST, "/") => handle_query(req, db_factory.clone()).await,
(&Method::GET, "/version") => Ok(handle_version()),
(&Method::GET, "/console") if enable_console => show_console().await,
(&Method::GET, "/health") => Ok(handle_health()),
Expand All @@ -272,29 +257,17 @@ fn handle_version() -> Response<Body> {

// TODO: refactor
#[allow(clippy::too_many_arguments)]
pub async fn run_http<F>(
pub async fn run_http(
addr: SocketAddr,
auth: Arc<Auth>,
db_factory_service: F,
db_factory: Arc<dyn DbFactory>,
upgrade_tx: mpsc::Sender<hrana::Upgrade>,
enable_console: bool,
idle_shutdown_layer: Option<IdleShutdownLayer>,
stats: Stats,
) -> anyhow::Result<()>
where
F: MakeService<(), Vec<Query>> + Send + 'static,
F::Service: Load + Service<Vec<Query>, Response = Vec<Option<QueryResult>>, Error = Error>,
<F::Service as Load>::Metric: std::fmt::Debug,
F::MakeError: Into<BoxError>,
F::Error: Into<BoxError>,
<F as MakeService<(), Vec<Query>>>::Service: Send,
<F as MakeService<(), Vec<Query>>>::Future: Send,
<<F as MakeService<(), Vec<Query>>>::Service as Service<Vec<Query>>>::Future: Send,
{
) -> anyhow::Result<()> {
tracing::info!("listening for HTTP requests on {addr}");

let (sender, mut receiver) = mpsc::channel(1024);
fn trace_request<B>(req: &Request<B>, _span: &Span) {
tracing::info!("got request: {} {}", req.method(), req.uri());
}
Expand All @@ -320,7 +293,6 @@ where
handle_request(
auth.clone(),
req,
sender.clone(),
upgrade_tx.clone(),
db_factory.clone(),
enable_console,
Expand All @@ -330,21 +302,6 @@ where

let server = hyper::server::Server::bind(&addr).serve(tower::make::Shared::new(service));

tokio::spawn(async move {
let mut pool = pool::Builder::new().build(db_factory_service, ());
while let Some(Message { batch, resp }) = receiver.recv().await {
if let Err(e) = poll_fn(|c| pool.poll_ready(c)).await {
tracing::error!("Connection pool error: {e}");
continue;
}

let fut = pool.call(batch);
tokio::spawn(async move {
let _ = resp.send(fut.await);
});
}
});

server.await.context("Http server exited with an error")?;

Ok(())
Expand Down
Loading

0 comments on commit bab0c66

Please sign in to comment.