Skip to content

Commit

Permalink
Speedup x2
Browse files — browse the repository at this point in the history
  • Loading branch information
AlSchlo committed Apr 24, 2024
1 parent 7574267 commit 7e372b8
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 41 deletions.
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions optd-datafusion-repr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ assert_approx_eq = "1.1.0"
serde = { version = "1.0", features = ["derive"] }
serde_with = {version = "3.7.0", features = ["json"]}
bincode = "1.3.3"
rayon = "1.10"
73 changes: 40 additions & 33 deletions optd-datafusion-repr/src/cost/base_cost/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use optd_gungnir::stats::{
tdigest::{self, TDigest},
};
use ordered_float::OrderedFloat;
use rayon::prelude::*;
use serde::{de::DeserializeOwned, Deserialize, Serialize};

// The "standard" concrete types that optd currently uses.
Expand Down Expand Up @@ -275,20 +276,23 @@ impl TableStats<Counter<ColumnCombValue>, TDigest<Value>> {
hlls: &mut [HyperLogLog<ColumnCombValue>],
null_counts: &mut [i32],
) {
for (idx, column_comb) in column_combs.iter().enumerate() {
// TODO(Alexis): Redundant copy.
let filtered_nulls: Vec<ColumnCombValue> = column_comb
.iter()
.filter(|row| row.iter().any(|val| val.is_some()))
.cloned()
.collect();
let nb_rows: i32 = column_comb.len() as i32;

null_counts[idx] += nb_rows - filtered_nulls.len() as i32;

mgs[idx].aggregate(&filtered_nulls);
hlls[idx].aggregate(&filtered_nulls);
}
column_combs
.par_iter()
.zip(mgs)
.zip(hlls)
.zip(null_counts)
.for_each(|(((column_comb, mg), hll), count)| {
let filtered_nulls: Vec<ColumnCombValue> = column_comb
.iter()
.filter(|row| row.iter().any(|val| val.is_some()))
.cloned()
.collect();
let nb_rows = column_comb.len() as i32;

*count += nb_rows - filtered_nulls.len() as i32;
mg.aggregate(&filtered_nulls);
hll.aggregate(&filtered_nulls);
});
}

fn generate_full_stats(
Expand All @@ -297,25 +301,28 @@ impl TableStats<Counter<ColumnCombValue>, TDigest<Value>> {
distrs: &mut [Option<TDigest<Value>>],
row_counts: &mut [i32],
) {
for (idx, column_comb) in column_combs.iter().enumerate() {
let nb_rows: i32 = column_comb.len() as i32;
row_counts[idx] += nb_rows;

cnts[idx].aggregate(column_comb);
if let Some(distr) = &mut distrs[idx] {
// TODO(Alexis): Redundant copy.
// We project it down to 1D, as we do not support nD TDigests.
let single_col_filtered = column_comb
.iter()
.filter(|row| !cnts[idx].is_tracking(row))
.filter_map(|row| row[0].as_ref())
.cloned()
.collect_vec();

distr.norm_weight += nb_rows as usize;
distr.merge_values(&single_col_filtered);
}
}
column_combs
.par_iter()
.zip(cnts)
.zip(distrs)
.zip(row_counts)
.for_each(|(((column_comb, cnt), distr), count)| {
let nb_rows = column_comb.len() as i32;
*count += nb_rows;
cnt.aggregate(column_comb);

if let Some(d) = distr.as_mut() {
let filtered_values: Vec<_> = column_comb
.iter()
.filter(|row| !cnt.is_tracking(row))
.filter_map(|row| row.get(0).and_then(|v| v.as_ref()))
.cloned()
.collect();

d.norm_weight += nb_rows as usize;
d.merge_values(&filtered_values);
}
});
}

pub fn from_record_batches<I: IntoIterator<Item = Result<RecordBatch, ArrowError>>>(
Expand Down
17 changes: 9 additions & 8 deletions optd-perftest/src/datafusion_dbms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion::{
sql::{parser::DFParser, sqlparser::dialect::GenericDialect},
};
use datafusion_optd_cli::helper::unescape_input;
use itertools::iproduct;
use lazy_static::lazy_static;
use optd_datafusion_bridge::{DatafusionCatalog, OptdQueryPlanner};
use optd_datafusion_repr::{
Expand Down Expand Up @@ -369,9 +370,9 @@ impl DatafusionDBMS {

let nb_cols = schema.fields().len();
let single_cols = (0..nb_cols).map(|v| vec![v]);
/*let pairwise_cols = iproduct!(0..nb_cols, 0..nb_cols)
.filter(|(i, j)| i != j)
.map(|(i, j)| vec![i, j]);*/
let pairwise_cols = iproduct!(0..nb_cols, 0..nb_cols)
.filter(|(i, j)| i != j)
.map(|(i, j)| vec![i, j]);

base_table_stats.insert(
tbl_name.to_string(),
Expand All @@ -385,7 +386,7 @@ impl DatafusionDBMS {
.unwrap();
Ok(RecordBatchIterator::new(csv_reader1, schema.clone()))
},
single_cols.collect(),
single_cols.chain(pairwise_cols).collect(),
)?,
);
}
Expand Down Expand Up @@ -429,9 +430,9 @@ impl DatafusionDBMS {

let nb_cols = schema.fields().len();
let single_cols = (0..nb_cols).map(|v| vec![v]);
/*let pairwise_cols = iproduct!(0..nb_cols, 0..nb_cols)
.filter(|(i, j)| i != j)
.map(|(i, j)| vec![i, j]);*/
let pairwise_cols = iproduct!(0..nb_cols, 0..nb_cols)
.filter(|(i, j)| i != j)
.map(|(i, j)| vec![i, j]);

base_table_stats.insert(
tbl_name.to_string(),
Expand All @@ -446,7 +447,7 @@ impl DatafusionDBMS {
.unwrap();
Ok(RecordBatchIterator::new(csv_reader1, schema.clone()))
},
single_cols.collect(),
single_cols.chain(pairwise_cols).collect(),
)?,
);
}
Expand Down

0 comments on commit 7e372b8

Please sign in to comment.