From 1e60b62c24dd67c8d61c0e5260f8ba0f8a6f842e Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Wed, 14 Aug 2024 18:54:23 -0400 Subject: [PATCH] statement stream --- diesel-wasm-sqlite/Cargo.lock | 2 +- diesel-wasm-sqlite/src/connection/mod.rs | 27 ++- .../src/connection/sqlite_value.rs | 2 +- .../src/connection/statement_iterator.rs | 172 -------------- .../src/connection/statement_stream.rs | 212 ++++++++++++++++++ diesel-wasm-sqlite/src/connection/stmt.rs | 2 +- diesel-wasm-sqlite/src/query_builder/mod.rs | 2 +- 7 files changed, 232 insertions(+), 187 deletions(-) delete mode 100644 diesel-wasm-sqlite/src/connection/statement_iterator.rs create mode 100644 diesel-wasm-sqlite/src/connection/statement_stream.rs diff --git a/diesel-wasm-sqlite/Cargo.lock b/diesel-wasm-sqlite/Cargo.lock index aadc60362..f179de053 100644 --- a/diesel-wasm-sqlite/Cargo.lock +++ b/diesel-wasm-sqlite/Cargo.lock @@ -142,7 +142,7 @@ dependencies = [ [[package]] name = "diesel-async" version = "0.5.0" -source = "git+https://github.com/insipx/diesel_async?branch=insipx/make-stmt-cache-public#86a24a38d9d841ef9e92022cd983bbd700286397" +source = "git+https://github.com/insipx/diesel_async?branch=insipx/make-stmt-cache-public#f1c4838ae6d7951b78572c249ba65f7a107488a0" dependencies = [ "async-trait", "diesel", diff --git a/diesel-wasm-sqlite/src/connection/mod.rs b/diesel-wasm-sqlite/src/connection/mod.rs index fac2bbe8b..7e09fea8e 100644 --- a/diesel-wasm-sqlite/src/connection/mod.rs +++ b/diesel-wasm-sqlite/src/connection/mod.rs @@ -5,7 +5,7 @@ mod raw; mod row; // mod serialized_database; mod sqlite_value; -// mod statement_iterator; +mod statement_stream; mod stmt; pub(crate) use self::bind_collector::SqliteBindCollector; @@ -17,14 +17,16 @@ use self::raw::RawConnection; // use self::statement_iterator::*; use self::stmt::{Statement, StatementUse}; use crate::query_builder::*; -use diesel::{connection::{statement_cache::StatementCacheKey, DefaultLoadingMode, LoadConnection}, deserialize::{FromSqlRow, StaticallySizedRow}, expression::QueryMetadata, query_builder::QueryBuilder as _, result::*, serialize::ToSql, sql_types::HasSqlType}; -use futures::{FutureExt, TryFutureExt}; +use diesel::{connection::{statement_cache::StatementCacheKey}, query_builder::QueryBuilder as _, result::*}; +use futures::future::LocalBoxFuture; +use futures::stream::LocalBoxStream; +use futures::FutureExt; +use statement_stream::PrivateStatementStream; use std::sync::{Arc, Mutex}; -use diesel::{connection::{ConnectionSealed, Instrumentation, WithMetadataLookup}, query_builder::{AsQuery, QueryFragment, QueryId}, sql_types::TypeMetadata, QueryResult}; +use diesel::{connection::{ConnectionSealed, Instrumentation}, query_builder::{AsQuery, QueryFragment, QueryId}, QueryResult}; pub use diesel_async::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection, TransactionManager, stmt_cache::StmtCache}; -use futures::{future::BoxFuture, stream::BoxStream}; use row::SqliteRow; use crate::{get_sqlite_unchecked, WasmSqlite, WasmSqliteError}; @@ -69,21 +71,24 @@ impl SimpleAsyncConnection for WasmSqliteConnection { impl AsyncConnection for WasmSqliteConnection { type Backend = WasmSqlite; type TransactionManager = AnsiTransactionManager; - type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; - type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult>>; - type Stream<'conn, 'query> = BoxStream<'static, QueryResult>>; + type ExecuteFuture<'conn, 'query> = LocalBoxFuture<'query, QueryResult>; + type LoadFuture<'conn, 'query> = LocalBoxFuture<'query, QueryResult>>; + type Stream<'conn, 'query> = LocalBoxStream<'query, QueryResult>>; type Row<'conn, 'query> = SqliteRow<'conn, 'query>; async fn establish(database_url: &str) -> diesel::prelude::ConnectionResult { WasmSqliteConnection::establish_inner(database_url).await } - fn load<'conn, 'query, T>(&'conn mut self, _source: T) -> Self::LoadFuture<'conn, 'query> + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where T: AsQuery + 'query, T::Query: QueryFragment + QueryId + 'query, { - todo!() + async { + let statement = self.prepared_query(source.as_query()).await?; + Ok(PrivateStatementStream::new(statement).stream()) + }.boxed_local() } fn execute_returning_count<'conn, 'query, T>( @@ -275,7 +280,7 @@ impl WasmSqliteConnection { } async fn establish_inner(database_url: &str) -> Result { - use diesel::result::ConnectionError::CouldntSetupConfiguration; + // use diesel::result::ConnectionError::CouldntSetupConfiguration; let raw_connection = RawConnection::establish(database_url).await.unwrap(); let sqlite3 = crate::get_sqlite().await; diff --git a/diesel-wasm-sqlite/src/connection/sqlite_value.rs b/diesel-wasm-sqlite/src/connection/sqlite_value.rs index e94e77217..ca47efa95 100644 --- a/diesel-wasm-sqlite/src/connection/sqlite_value.rs +++ b/diesel-wasm-sqlite/src/connection/sqlite_value.rs @@ -2,7 +2,7 @@ use std::cell::Ref; -use crate::ffi::{self, SQLiteCompatibleType}; +use crate::ffi::SQLiteCompatibleType; use crate::{backend::SqliteType, sqlite_types}; use wasm_bindgen::JsValue; diff --git a/diesel-wasm-sqlite/src/connection/statement_iterator.rs b/diesel-wasm-sqlite/src/connection/statement_iterator.rs deleted file mode 100644 index 393ec9e47..000000000 --- a/diesel-wasm-sqlite/src/connection/statement_iterator.rs +++ /dev/null @@ -1,172 +0,0 @@ -use std::cell::RefCell; -use std::rc::Rc; - -use super::row::{PrivateSqliteRow, SqliteRow}; -use super::stmt::StatementUse; -use crate::result::QueryResult; - -#[allow(missing_debug_implementations)] -pub struct StatementIterator<'stmt, 'query> { - inner: PrivateStatementIterator<'stmt, 'query>, - column_names: Option]>>, - field_count: usize, -} - -impl<'stmt, 'query> StatementIterator<'stmt, 'query> { - #[cold] - #[allow(unsafe_code)] // call to unsafe function - fn handle_duplicated_row_case( - outer_last_row: &mut Rc>>, - column_names: &mut Option]>>, - field_count: usize, - ) -> Option>> { - // We don't own the statement. There is another existing reference, likely because - // a user stored the row in some long time container before calling next another time - // In this case we copy out the current values into a temporary store and advance - // the statement iterator internally afterwards - let last_row = { - let mut last_row = match outer_last_row.try_borrow_mut() { - Ok(o) => o, - Err(_e) => { - return Some(Err(crate::result::Error::DeserializationError( - "Failed to reborrow row. Try to release any `SqliteField` or `SqliteValue` \ - that exists at this point" - .into(), - ))); - } - }; - let last_row = &mut *last_row; - let duplicated = last_row.duplicate(column_names); - std::mem::replace(last_row, duplicated) - }; - if let PrivateSqliteRow::Direct(mut stmt) = last_row { - let res = unsafe { - // This is actually safe here as we've already - // performed one step. For the first step we would have - // used `PrivateStatementIterator::NotStarted` where we don't - // have access to `PrivateSqliteRow` at all - stmt.step(false) - }; - *outer_last_row = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); - match res { - Err(e) => Some(Err(e)), - Ok(false) => None, - Ok(true) => Some(Ok(SqliteRow { - inner: Rc::clone(outer_last_row), - field_count, - })), - } - } else { - // any other state than `PrivateSqliteRow::Direct` is invalid here - // and should not happen. If this ever happens this is a logic error - // in the code above - unreachable!( - "You've reached an impossible internal state. \ - If you ever see this error message please open \ - an issue at https://github.com/diesel-rs/diesel \ - providing example code how to trigger this error." - ) - } - } -} - -enum PrivateStatementIterator<'stmt, 'query> { - NotStarted(Option>), - Started(Rc>>), -} - -impl<'stmt, 'query> StatementIterator<'stmt, 'query> { - pub fn new(stmt: StatementUse<'stmt, 'query>) -> StatementIterator<'stmt, 'query> { - Self { - inner: PrivateStatementIterator::NotStarted(Some(stmt)), - column_names: None, - field_count: 0, - } - } -} - -impl<'stmt, 'query> Iterator for StatementIterator<'stmt, 'query> { - type Item = QueryResult>; - - #[allow(unsafe_code)] // call to unsafe function - fn next(&mut self) -> Option { - use PrivateStatementIterator::{NotStarted, Started}; - match &mut self.inner { - NotStarted(ref mut stmt @ Some(_)) => { - let mut stmt = stmt - .take() - .expect("It must be there because we checked that above"); - let step = unsafe { - // This is safe as we pass `first_step = true` to reset the cached column names - stmt.step(true) - }; - match step { - Err(e) => Some(Err(e)), - Ok(false) => None, - Ok(true) => { - let field_count = stmt.column_count() as usize; - self.field_count = field_count; - let inner = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); - self.inner = Started(inner.clone()); - Some(Ok(SqliteRow { inner, field_count })) - } - } - } - Started(ref mut last_row) => { - // There was already at least one iteration step - // We check here if the caller already released the row value or not - // by checking if our Rc owns the data or not - if let Some(last_row_ref) = Rc::get_mut(last_row) { - // We own the statement, there is no other reference here. - // This means we don't need to copy out values from the sqlite provided - // datastructures for now - // We don't need to use the runtime borrowing system of the RefCell here - // as we have a mutable reference, so all of this below is checked at compile time - if let PrivateSqliteRow::Direct(ref mut stmt) = last_row_ref.get_mut() { - let step = unsafe { - // This is actually safe here as we've already - // performed one step. For the first step we would have - // used `PrivateStatementIterator::NotStarted` where we don't - // have access to `PrivateSqliteRow` at all - - stmt.step(false) - }; - match step { - Err(e) => Some(Err(e)), - Ok(false) => None, - Ok(true) => { - let field_count = self.field_count; - Some(Ok(SqliteRow { - inner: Rc::clone(last_row), - field_count, - })) - } - } - } else { - // any other state than `PrivateSqliteRow::Direct` is invalid here - // and should not happen. If this ever happens this is a logic error - // in the code above - unreachable!( - "You've reached an impossible internal state. \ - If you ever see this error message please open \ - an issue at https://github.com/diesel-rs/diesel \ - providing example code how to trigger this error." - ) - } - } else { - Self::handle_duplicated_row_case( - last_row, - &mut self.column_names, - self.field_count, - ) - } - } - NotStarted(_s) => { - // we likely got an error while executing the other - // `NotStarted` branch above. In this case we just want to stop - // iterating here - None - } - } - } -} diff --git a/diesel-wasm-sqlite/src/connection/statement_stream.rs b/diesel-wasm-sqlite/src/connection/statement_stream.rs new file mode 100644 index 000000000..5349fc396 --- /dev/null +++ b/diesel-wasm-sqlite/src/connection/statement_stream.rs @@ -0,0 +1,212 @@ +use std::cell::RefCell; +use std::rc::Rc; + +use super::row::{PrivateSqliteRow, SqliteRow}; +use super::stmt::StatementUse; +use diesel::result::QueryResult; +use futures::stream::LocalBoxStream; +use futures::{Stream, TryStreamExt}; + +pub struct StatementStream<'stmt, 'quer> { + stream: LocalBoxStream<'query, QueryResult>>, +} + +impl<'stmt, 'query> Stream for StatementStream<'stmt, 'query> { + type Item = QueryResult>; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let stream = &mut self.stream; + stream.try_poll_next_unpin(cx) + } +} + +#[allow(missing_debug_implementations)] +pub struct PrivateStatementStream<'stmt, 'query> { + inner: StatementStreamState<'stmt, 'query>, + column_names: Option]>>, + field_count: usize, +} + +impl<'stmt, 'query> PrivateStatementStream<'stmt, 'query> { + #[cold] + async fn handle_duplicated_row_case( + outer_last_row: &mut Rc>>, + column_names: &mut Option]>>, + field_count: usize, + ) -> Option>> { + // We don't own the statement. There is another existing reference, likely because + // a user stored the row in some long time container before calling next another time + // In this case we copy out the current values into a temporary store and advance + // the statement iterator internally afterwards + let last_row = { + let mut last_row = match outer_last_row.try_borrow_mut() { + Ok(o) => o, + Err(_e) => { + return Some(Err(diesel::result::Error::DeserializationError( + "Failed to reborrow row. Try to release any `SqliteField` or `SqliteValue` \ + that exists at this point" + .into(), + ))); + } + }; + let last_row = &mut *last_row; + let duplicated = last_row.duplicate(column_names); + std::mem::replace(last_row, duplicated) + }; + if let PrivateSqliteRow::Direct(mut stmt) = last_row { + let res = stmt.step(false).await; + *outer_last_row = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); + match res { + Err(e) => Some(Err(e)), + Ok(false) => None, + Ok(true) => Some(Ok(SqliteRow { + inner: Rc::clone(outer_last_row), + field_count, + })), + } + } else { + // any other state than `PrivateSqliteRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + } +} + +enum StatementStreamState<'stmt, 'query> { + NotStarted(Option>), + Started(Rc>>), +} + +impl<'stmt, 'query> PrivateStatementStream<'stmt, 'query> { + pub fn new(stmt: StatementUse<'stmt, 'query>) -> PrivateStatementStream<'stmt, 'query> { + Self { + inner: StatementStreamState::NotStarted(Some(stmt)), + column_names: None, + field_count: 0, + } + } +} +/// Rolling a custom `Stream` impl on PrivateStatementStream was taking too long/tricky +/// so using `futures::unfold`. Rolling a custom `Stream` would probably be better, +/// but performance wise/code-readability sense not very different +impl<'stmt, 'query> PrivateStatementStream<'stmt, 'query> { + pub fn stream(self) -> LocalBoxStream<'query, QueryResult>> { + use StatementStreamState::{NotStarted, Started}; + let stream = futures::stream::unfold(self, |mut statement| async move { + match statement.inner { + NotStarted(mut stmt @ Some(_)) => { + let mut stmt = stmt + .take() + .expect("It must be there because we checked that above"); + match stmt.step(true).await { + Ok(true) => { + let field_count = stmt.column_count() as usize; + statement.field_count = field_count; + let inner = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); + let new_inner = inner.clone(); + Some(( + Ok(SqliteRow { inner, field_count }), + Self { + inner: Started(new_inner), + ..statement + }, + )) + } + Ok(false) => None, + Err(e) => Some(( + Err(e), + Self { + inner: NotStarted(Some(stmt)), + ..statement + }, + )), + } + // res.poll_next(cx).map(|t| t.flatten()) + } + Started(ref mut last_row) => { + // There was already at least one iteration step + // We check here if the caller already released the row value or not + // by checking if our Rc owns the data or not + if let Some(last_row_ref) = Rc::get_mut(last_row) { + // We own the statement, there is no other reference here. + // This means we don't need to copy out values from the sqlite provided + // datastructures for now + // We don't need to use the runtime borrowing system of the RefCell here + // as we have a mutable reference, so all of this below is checked at compile time + if let PrivateSqliteRow::Direct(ref mut stmt) = last_row_ref.get_mut() { + // This is actually safe here as we've already + // performed one step. For the first step we would have + // used `StatementStreamState::NotStarted` where we don't + // have access to `PrivateSqliteRow` at all + match stmt.step(false).await { + Err(e) => Some(( + Err(e), + Self { + inner: Started(Rc::clone(last_row)), + ..statement + }, + )), + Ok(false) => None, + Ok(true) => { + let field_count = statement.field_count; + Some(( + Ok(SqliteRow { + inner: Rc::clone(last_row), + field_count, + }), + Self { + inner: Started(Rc::clone(last_row)), + ..statement + }, + )) + } + } + } else { + // any other state than `PrivateSqliteRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + } else { + let res = Self::handle_duplicated_row_case( + last_row, + &mut statement.column_names, + statement.field_count, + ) + .await; + res.map(|r| { + ( + r, + Self { + inner: Started(Rc::clone(last_row)), + ..statement + }, + ) + }) + } + } + NotStarted(_s) => { + // we likely got an error while executing the other + // `NotStarted` branch above. In this case we just want to stop + // iterating here + None + } + } + }); + Box::pin(stream) + } +} diff --git a/diesel-wasm-sqlite/src/connection/stmt.rs b/diesel-wasm-sqlite/src/connection/stmt.rs index 53fb5ba78..ec8af5cff 100644 --- a/diesel-wasm-sqlite/src/connection/stmt.rs +++ b/diesel-wasm-sqlite/src/connection/stmt.rs @@ -13,7 +13,7 @@ use diesel::{ Instrumentation, }, query_builder::{QueryFragment, QueryId}, - result::{Error::DatabaseError, *}, + result::{Error, QueryResult}, }; use std::cell::OnceCell; use std::sync::Mutex; diff --git a/diesel-wasm-sqlite/src/query_builder/mod.rs b/diesel-wasm-sqlite/src/query_builder/mod.rs index 8bf48119b..3489876b8 100644 --- a/diesel-wasm-sqlite/src/query_builder/mod.rs +++ b/diesel-wasm-sqlite/src/query_builder/mod.rs @@ -6,7 +6,7 @@ use diesel::result::QueryResult; mod limit_offset; mod query_fragment_impls; -mod returning; +// mod returning; /// Constructs SQL queries for use with the SQLite backend #[allow(missing_debug_implementations)]