Skip to content

Commit

Permalink
More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewmturner committed Dec 6, 2024
1 parent 28cbcc2 commit 0cabdc9
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 50 deletions.
17 changes: 9 additions & 8 deletions src/execution/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,16 @@ impl ExecutionContext {
})
}

/// Apply all enabled extensions to the `SessionContext`
pub async fn register_extensions(&mut self) -> Result<()> {
let ctx = &mut self.session_ctx;
let config = &self.config;
let extensions = enabled_extensions();
// Apply any additional setup to the session context (e.g. registering
// functions)
for extension in &extensions {
extension.register_on_ctx(config, ctx)?;
}
// let ctx = &mut self.session_ctx;
// let config = &self.config;

// let extensions = enabled_extensions();
//
// for extension in &extensions {
// extension.register_on_ctx(config, ctx).await?;
// }

Ok(())
}
Expand Down
34 changes: 28 additions & 6 deletions src/extensions/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! [`DftSessionStateBuilder`] for configuring DataFusion [`SessionState`]
use color_eyre::eyre;
use datafusion::catalog::TableProviderFactory;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::RuntimeEnv;
Expand All @@ -26,6 +27,10 @@ use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;

use crate::config::ExecutionConfig;

use super::{enabled_extensions, Extension};

/// Builds a DataFusion [`SessionState`] with any necessary configuration
///
/// Ideally we would use the DataFusion [`SessionStateBuilder`], but it doesn't
Expand Down Expand Up @@ -88,11 +93,7 @@ impl DftSessionStateBuilder {
}

/// Add a table factory to the list of factories on this builder
pub fn with_table_factory(
mut self,
name: &str,
factory: Arc<dyn TableProviderFactory>,
) -> Self {
pub fn add_table_factory(&mut self, name: &str, factory: Arc<dyn TableProviderFactory>) {
if self.table_factories.is_none() {
self.table_factories = Some(HashMap::from([(name.to_string(), factory)]));
} else {
Expand All @@ -101,7 +102,6 @@ impl DftSessionStateBuilder {
.unwrap()
.insert(name.to_string(), factory);
}
self
}

/// Return the current [`RuntimeEnv`], creating a default if it doesn't exist
Expand All @@ -112,6 +112,28 @@ impl DftSessionStateBuilder {
self.runtime_env.as_ref().unwrap()
}

/// Register a single [`Extension`] on this builder.
///
/// Invokes the extension's `register` hook with `config` and a mutable
/// reference to this builder.
///
/// # Errors
///
/// Returns an error that names the extension and carries the underlying
/// failure message when the extension's `register` hook fails.
pub async fn register_extension(
    &mut self,
    config: ExecutionConfig,
    extension: Arc<dyn Extension>,
) -> color_eyre::Result<()> {
    extension
        .register(config, self)
        .await
        // Keep the original failure and say which extension produced it,
        // instead of collapsing every error into an opaque message.
        .map_err(|e| eyre::eyre!("Failed to register extension {:?}: {}", extension, e))
}

/// Register every extension returned by [`enabled_extensions`] on this
/// builder, handing each one its own clone of `config`.
///
/// Stops at, and returns, the first registration error.
pub async fn register_extensions(&mut self, config: ExecutionConfig) -> color_eyre::Result<()> {
    for ext in enabled_extensions() {
        let ext_config = config.clone();
        self.register_extension(ext_config, ext).await?;
    }

    Ok(())
}

/// Build the [`SessionState`] from the specified configuration
pub fn build(self) -> datafusion_common::Result<SessionState> {
let Self {
Expand Down
8 changes: 5 additions & 3 deletions src/extensions/deltalake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ impl Extension for DeltaLakeExtension {
async fn register(
&self,
_config: ExecutionConfig,
builder: DftSessionStateBuilder,
) -> datafusion_common::Result<DftSessionStateBuilder> {
Ok(builder.with_table_factory("DELTATABLE", Arc::new(DeltaTableFactory {})))
builder: &mut DftSessionStateBuilder,
) -> datafusion_common::Result<()> {
println!("Registering deltalake");
builder.add_table_factory("DELTATABLE", Arc::new(DeltaTableFactory {}));
Ok(())
}
}
6 changes: 3 additions & 3 deletions src/extensions/functions_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ impl Extension for JsonFunctionsExtension {
async fn register(
&self,
_config: ExecutionConfig,
builder: DftSessionStateBuilder,
) -> datafusion_common::Result<DftSessionStateBuilder> {
Ok(builder)
builder: &mut DftSessionStateBuilder,
) -> datafusion_common::Result<()> {
Ok(())
}

fn register_on_ctx(&self, _config: &ExecutionConfig, ctx: &mut SessionContext) -> Result<()> {
Expand Down
11 changes: 6 additions & 5 deletions src/extensions/iceberg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ impl Extension for IcebergExtension {
async fn register(
&self,
_config: ExecutionConfig,
builder: DftSessionStateBuilder,
) -> datafusion_common::Result<DftSessionStateBuilder> {
Ok(builder.with_table_factory("ICEBERG", Arc::new(IcebergTableProviderFactory {})));

let catalog_provider = IcebergCatalogProvider::try_new(catalog).await?;
builder: &mut DftSessionStateBuilder,
) -> datafusion_common::Result<()> {
Ok(())
// Ok(builder.with_table_factory("ICEBERG", Arc::new(IcebergTableProviderFactory {})));
//
// let catalog_provider = IcebergCatalogProvider::try_new(catalog).await?;
}
}
22 changes: 11 additions & 11 deletions src/extensions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use crate::config::ExecutionConfig;
use datafusion::common::Result;
use datafusion::prelude::SessionContext;
use std::fmt::Debug;
use std::{fmt::Debug, sync::Arc};

mod builder;
#[cfg(feature = "deltalake")]
Expand All @@ -40,27 +40,27 @@ pub trait Extension: Debug {
async fn register(
&self,
_config: ExecutionConfig,
_builder: DftSessionStateBuilder,
) -> Result<DftSessionStateBuilder>;
_builder: &mut DftSessionStateBuilder,
) -> Result<()>;

/// Registers this extension after the SessionContext has been created
/// (this is to match the historic way many extensions were registered)
/// TODO file a ticket upstream to use the builder pattern
// Registers this extension after the SessionContext has been created
// (this is to match the historic way many extensions were registered)
// TODO file a ticket upstream to use the builder pattern
fn register_on_ctx(&self, _config: &ExecutionConfig, _ctx: &mut SessionContext) -> Result<()> {
Ok(())
}
}

/// Return all extensions currently enabled
pub fn enabled_extensions() -> Vec<Box<dyn Extension>> {
pub fn enabled_extensions() -> Vec<Arc<dyn Extension>> {
vec![
#[cfg(feature = "s3")]
Box::new(s3::AwsS3Extension::new()),
Arc::new(s3::AwsS3Extension::new()),
#[cfg(feature = "deltalake")]
Box::new(deltalake::DeltaLakeExtension::new()),
Arc::new(deltalake::DeltaLakeExtension::new()),
#[cfg(feature = "iceberg")]
Box::new(iceberg::IcebergExtension::new()),
Arc::new(iceberg::IcebergExtension::new()),
#[cfg(feature = "functions-json")]
Box::new(functions_json::JsonFunctionsExtension::new()),
Arc::new(functions_json::JsonFunctionsExtension::new()),
]
}
10 changes: 5 additions & 5 deletions src/extensions/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ impl Extension for AwsS3Extension {
async fn register(
&self,
config: ExecutionConfig,
mut builder: DftSessionStateBuilder,
) -> datafusion_common::Result<DftSessionStateBuilder> {
builder: &mut DftSessionStateBuilder,
) -> datafusion_common::Result<()> {
let Some(object_store_config) = &config.object_store else {
return Ok(builder);
return Ok(());
};

let Some(s3_configs) = &object_store_config.s3 else {
return Ok(builder);
return Ok(());
};

info!("S3 configs exists");
Expand All @@ -70,6 +70,6 @@ impl Extension for AwsS3Extension {
}
}

Ok(builder)
Ok(())
}
}
8 changes: 6 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use dft::cli::CliApp;
#[cfg(feature = "flightsql")]
use dft::execution::flightsql::FlightSQLContext;
use dft::execution::{local::ExecutionContext, AppExecution, AppType};
use dft::extensions::DftSessionStateBuilder;
#[cfg(feature = "experimental-flightsql-server")]
use dft::server::FlightSqlApp;
use dft::telemetry;
Expand Down Expand Up @@ -56,9 +57,12 @@ async fn app_entry_point(cli: DftArgs, state: AppState<'_>) -> Result<()> {
const DEFAULT_SERVER_ADDRESS: &str = "127.0.0.1:50051";
info!("Starting FlightSQL server on {}", DEFAULT_SERVER_ADDRESS);
let state = state::initialize(cli.config_path());
let mut execution_ctx =
let mut session_state = DftSessionStateBuilder::new();
session_state
.register_extensions(state.config.execution.clone())
.await;
let execution_ctx =
ExecutionContext::try_new(&state.config.execution, AppType::FlightSQLServer)?;
execution_context.register_extensions().await?;
if cli.run_ddl {
execution_ctx.execute_ddl().await;
}
Expand Down
3 changes: 2 additions & 1 deletion tests/extension_cases/deltalake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ use crate::extension_cases::TestExecution;

#[tokio::test(flavor = "multi_thread")]
async fn test_deltalake() {
let test_exec = TestExecution::new();
let mut test_exec = TestExecution::new();
test_exec.register_extensions().await;

let cwd = std::env::current_dir().unwrap();
let path = Url::from_file_path(cwd.join("data/deltalake/simple_table")).unwrap();
Expand Down
16 changes: 11 additions & 5 deletions tests/extension_cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,20 @@ impl TestExecution {
pub fn new() -> Self {
let config = AppConfig::default();

let mut execution = ExecutionContext::try_new(&config.execution, AppType::Cli).unwrap();
let fut = execution.register_extensions();
tokio::task::block_in_place(move || {
tokio::runtime::Handle::current().block_on(fut).unwrap()
});
let execution = ExecutionContext::try_new(&config.execution, AppType::Cli).unwrap();
// let fut = execution.register_extensions();
// tokio::task::block_in_place(move || {
// tokio::runtime::Handle::current().block_on(fut).unwrap()
// });
Self { execution }
}

/// Register extensions on the wrapped `ExecutionContext`.
///
/// # Panics
///
/// Panics if registration returns an error.
pub async fn register_extensions(&mut self) {
    println!("Registering extensions in TestExecution");
    self.execution
        .register_extensions()
        .await
        .expect("failed to register extensions in TestExecution")
}

/// Run the setup SQL query, discarding the result
#[allow(dead_code)]
pub async fn with_setup(self, sql: &str) -> Self {
Expand Down
2 changes: 1 addition & 1 deletion tests/extension_cases/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use assert_cmd::Command;

use crate::{cli_cases::contains_str, config::TestConfigBuilder};

#[test(flavor = "multi_thread")]
#[tokio::test(flavor = "multi_thread")]
fn test_s3_basic() {
let tempdir = tempfile::tempdir().unwrap();
let ddl_path = tempdir.path().join("my_ddl.sql");
Expand Down

0 comments on commit 0cabdc9

Please sign in to comment.