Skip to content

Commit

Permalink
fix filtered inputs to aggregates
Browse files (browse the repository at this point in the history)
  • Loading branch information
davidhewitt committed Sep 25, 2024
1 parent 0daf050 commit d5895c9
Showing 1 changed file with 72 additions and 22 deletions.
94 changes: 72 additions & 22 deletions datafusion-examples/examples/cache_query.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,27 +17,39 @@

use std::sync::Arc;

use arrow::util::pretty::print_batches;
use datafusion::arrow::array::{UInt64Array, UInt8Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::ScalarValue;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use arrow::util::pretty::print_batches;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::logical_expr::Operator;
use datafusion::physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion::physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
let mem_table = create_memtable()?;
let input_batch = create_data();
let input = Arc::new(MemoryExec::try_new(
&vec![vec![input_batch.clone()]],
input_batch.schema(),
None,
)?);

// create local execution context
let ctx = SessionContext::new();

// Register the in-memory table containing the data
ctx.register_table("users", Arc::new(mem_table))?;
ctx.register_batch("users", input_batch.clone())?;
let df_old = ctx.sql("SELECT avg(foo) FROM users where id=1;").await?;
let batches = df_old.clone().collect().await?;
print_batches(&batches)?;
Expand All @@ -46,13 +58,27 @@ async fn main() -> Result<()> {
let plan = df_old.clone().create_physical_plan().await?;
let exec = plan.as_any().downcast_ref::<AggregateExec>().unwrap();

let filtered_input = Arc::new(
FilterExec::try_new(
Arc::new(BinaryExpr::new(
Arc::new(Column::new("id", 0)),
Operator::Eq,
Arc::new(Literal::new(1u8.into())),
)),
input.clone(),
)?
.with_projection(Some(vec![1]))?,
);

let input_schema = filtered_input.schema();

Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
exec.group_expr().clone(),
exec.aggr_expr().to_vec(),
exec.filter_expr().to_vec(),
exec.input().clone(),
exec.input_schema(),
vec![None],
filtered_input,
input_schema,
)?)
};

Expand All @@ -68,14 +94,33 @@ async fn main() -> Result<()> {
// dbg!(&exec);
// dbg!(exec.input());

let agg_new = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
exec.group_expr().clone(),
exec.aggr_expr().to_vec(),
exec.filter_expr().to_vec(),
exec.input().clone(),
exec.input_schema(),
)?);
let agg_new = {
let plan = dataframe.clone().create_physical_plan().await?;
let exec = plan.as_any().downcast_ref::<AggregateExec>().unwrap();

let filtered_input = Arc::new(
FilterExec::try_new(
Arc::new(BinaryExpr::new(
Arc::new(Column::new("id", 0)),
Operator::Eq,
Arc::new(Literal::new(4u8.into())),
)),
input.clone(),
)?
.with_projection(Some(vec![1]))?,
);

let input_schema = filtered_input.schema();

Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
exec.group_expr().clone(),
exec.aggr_expr().to_vec(),
vec![None],
filtered_input,
input_schema,
)?)
};

// let previous_schema = results_previous.first().unwrap().schema().clone();
// let previous_results = MemoryExec::try_new(
Expand All @@ -85,6 +130,12 @@ async fn main() -> Result<()> {
// ).map(Arc::new)?;

let combined_input = Arc::new(UnionExec::new(vec![agg_old, agg_new]));

let task_ctx = Arc::new(dataframe.task_ctx());
let batches = collect(combined_input.clone(), task_ctx).await?;

print_batches(&batches)?;

let combined_input = Arc::new(CoalescePartitionsExec::new(combined_input));
let input_schema = combined_input.schema();
//
Expand All @@ -109,7 +160,7 @@ async fn main() -> Result<()> {
Ok(())
}

fn create_memtable() -> Result<MemTable> {
fn create_data() -> RecordBatch {
let schema = SchemaRef::new(Schema::new(vec![
Field::new("id", DataType::UInt8, false),
Field::new("foo", DataType::UInt64, true),
Expand All @@ -118,10 +169,9 @@ fn create_memtable() -> Result<MemTable> {
let id_array = UInt8Array::from(vec![1, 1, 1, 4, 5]);
let account_array = UInt64Array::from(vec![9000, 9001, 9002, 9003, 9004]);

let b = RecordBatch::try_new(
RecordBatch::try_new(
schema.clone(),
vec![Arc::new(id_array), Arc::new(account_array)],
)?;

MemTable::try_new(schema, vec![vec![b]])
)
.expect("Error creating RecordBatch")
}

0 comments on commit d5895c9

Please sign in to comment.