diff --git a/crates/polars-stream/src/execute.rs b/crates/polars-stream/src/execute.rs index 2d68cae2c90e..74b9c6b7edb5 100644 --- a/crates/polars-stream/src/execute.rs +++ b/crates/polars-stream/src/execute.rs @@ -247,6 +247,18 @@ pub fn execute_graph( let num_pipelines = POOL.current_num_threads(); async_executor::set_num_threads(num_pipelines); + // Ensure everything is properly connected. + for (node_key, node) in &graph.nodes { + for (i, input) in node.inputs.iter().enumerate() { + assert!(graph.pipes[*input].receiver == node_key); + assert!(graph.pipes[*input].recv_port == i); + } + for (i, output) in node.outputs.iter().enumerate() { + assert!(graph.pipes[*output].sender == node_key); + assert!(graph.pipes[*output].send_port == i); + } + } + for node in graph.nodes.values_mut() { node.compute.initialize(num_pipelines); } diff --git a/crates/polars-stream/src/graph.rs b/crates/polars-stream/src/graph.rs index 572c1f1c306d..3129653df1a6 100644 --- a/crates/polars-stream/src/graph.rs +++ b/crates/polars-stream/src/graph.rs @@ -1,5 +1,5 @@ use polars_error::PolarsResult; -use slotmap::{SecondaryMap, SlotMap}; +use slotmap::{Key, SecondaryMap, SlotMap}; use crate::nodes::ComputeNode; @@ -32,7 +32,7 @@ impl Graph { pub fn add_node( &mut self, node: N, - inputs: impl IntoIterator, + inputs: impl IntoIterator, ) -> GraphNodeKey { // Add the GraphNode. let node_key = self.nodes.insert(GraphNode { @@ -42,8 +42,7 @@ impl Graph { }); // Create and add pipes that connect input to output. - for (recv_port, sender) in inputs.into_iter().enumerate() { - let send_port = self.nodes[sender].outputs.len(); + for (recv_port, (sender, send_port)) in inputs.into_iter().enumerate() { let pipe = LogicalPipe { sender, send_port, @@ -58,7 +57,13 @@ impl Graph { // And connect input to output. 
self.nodes[node_key].inputs.push(pipe_key); - self.nodes[sender].outputs.push(pipe_key); + if self.nodes[sender].outputs.len() <= send_port { + self.nodes[sender] + .outputs + .resize(send_port + 1, LogicalPipeKey::null()); + } + assert!(self.nodes[sender].outputs[send_port].is_null()); + self.nodes[sender].outputs[send_port] = pipe_key; } node_key @@ -142,14 +147,14 @@ pub struct LogicalPipe { pub sender: GraphNodeKey, // Output location: // graph[x].output[i].send_port == i - send_port: usize, + pub send_port: usize, pub send_state: PortState, // Node that we receive data from. pub receiver: GraphNodeKey, // Input location: // graph[x].inputs[i].recv_port == i - recv_port: usize, + pub recv_port: usize, pub recv_state: PortState, } diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs index 11c701abdbbf..0c829e0ec029 100644 --- a/crates/polars-stream/src/physical_plan/fmt.rs +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -251,10 +251,10 @@ fn visualize_plan_rec( label )); for input in inputs { - visualize_plan_rec(*input, phys_sm, expr_arena, visited, out); + visualize_plan_rec(input.node, phys_sm, expr_arena, visited, out); out.push(format!( "{} -> {};", - input.data().as_ffi(), + input.node.data().as_ffi(), node_key.data().as_ffi() )); } diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index a4aaa7201818..1df9b0e35d51 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -17,7 +17,7 @@ use polars_utils::pl_str::PlSmallStr; use polars_utils::{format_pl_smallstr, unitvec}; use slotmap::SlotMap; -use super::{PhysNode, PhysNodeKey, PhysNodeKind}; +use super::{PhysNode, PhysNodeKey, PhysNodeKind, PhysStream}; type IRNodeKey = Node; @@ -228,68 +228,13 @@ fn build_input_independent_node_with_ctx( ))) } -fn simplify_input_nodes( - orig_input: PhysNodeKey, - mut 
input_nodes: PlHashSet, - ctx: &mut LowerExprContext, -) -> PolarsResult> { - // Flatten nested zips (ensures the original input columns only occur once). - if input_nodes.len() > 1 { - let mut flattened_input_nodes = PlHashSet::with_capacity(input_nodes.len()); - for input_node in input_nodes { - if let PhysNodeKind::Zip { - inputs, - null_extend: false, - } = &ctx.phys_sm[input_node].kind - { - flattened_input_nodes.extend(inputs); - ctx.phys_sm.remove(input_node); - } else { - flattened_input_nodes.insert(input_node); - } - } - input_nodes = flattened_input_nodes; - } - - // Merge reduce nodes that directly operate on the original input. - let mut combined_exprs = vec![]; - input_nodes = input_nodes - .into_iter() - .filter(|input_node| { - if let PhysNodeKind::Reduce { - input: inner, - exprs, - } = &ctx.phys_sm[*input_node].kind - { - if *inner == orig_input { - combined_exprs.extend(exprs.iter().cloned()); - ctx.phys_sm.remove(*input_node); - return false; - } - } - true - }) - .collect(); - if !combined_exprs.is_empty() { - let output_schema = schema_for_select(orig_input, &combined_exprs, ctx)?; - let kind = PhysNodeKind::Reduce { - input: orig_input, - exprs: combined_exprs, - }; - let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); - input_nodes.insert(reduce_node_key); - } - - Ok(input_nodes) -} - fn build_fallback_node_with_ctx( - input: PhysNodeKey, + input: PhysStream, exprs: &[ExprIR], ctx: &mut LowerExprContext, ) -> PolarsResult { // Pre-select only the columns that are needed for this fallback expression. 
- let input_schema = &ctx.phys_sm[input].output_schema; + let input_schema = &ctx.phys_sm[input.node].output_schema; let mut select_names: PlHashSet<_> = exprs .iter() .flat_map(|expr| polars_plan::utils::aexpr_to_leaf_names_iter(expr.node(), ctx.expr_arena)) @@ -301,7 +246,7 @@ fn build_fallback_node_with_ctx( select_names.insert(name.clone()); } } - let input_node = if input_schema + let input_stream = if input_schema .iter_names() .any(|name| !select_names.contains(name.as_str())) { @@ -314,12 +259,12 @@ fn build_fallback_node_with_ctx( ) }) .collect_vec(); - build_select_node_with_ctx(input, &select_exprs, ctx)? + build_select_stream_with_ctx(input, &select_exprs, ctx)? } else { input }; - let output_schema = schema_for_select(input_node, exprs, ctx)?; + let output_schema = schema_for_select(input_stream, exprs, ctx)?; let expr_depth_limit = get_expr_depth_limit()?; let mut conv_state = ExpressionConversionState::new(false, expr_depth_limit); let phys_exprs = exprs @@ -329,7 +274,7 @@ fn build_fallback_node_with_ctx( expr, Context::Default, ctx.expr_arena, - &ctx.phys_sm[input_node].output_schema, + &ctx.phys_sm[input_stream.node].output_schema, &mut conv_state, ) }) @@ -343,20 +288,75 @@ fn build_fallback_node_with_ctx( DataFrame::new_with_broadcast(columns) }; let kind = PhysNodeKind::InMemoryMap { - input: input_node, + input: input_stream, map: Arc::new(map), }; Ok(ctx.phys_sm.insert(PhysNode::new(output_schema, kind))) } +fn simplify_input_streams( + orig_input: PhysStream, + mut input_streams: PlHashSet, + ctx: &mut LowerExprContext, +) -> PolarsResult> { + // Flatten nested zips (ensures the original input columns only occur once). 
+ if input_streams.len() > 1 { + let mut flattened_input_streams = PlHashSet::with_capacity(input_streams.len()); + for input_stream in input_streams { + if let PhysNodeKind::Zip { + inputs, + null_extend: false, + } = &ctx.phys_sm[input_stream.node].kind + { + flattened_input_streams.extend(inputs); + ctx.phys_sm.remove(input_stream.node); + } else { + flattened_input_streams.insert(input_stream); + } + } + input_streams = flattened_input_streams; + } + + // Merge reduce nodes that directly operate on the original input. + let mut combined_exprs = vec![]; + input_streams = input_streams + .into_iter() + .filter(|input_stream| { + if let PhysNodeKind::Reduce { + input: inner, + exprs, + } = &ctx.phys_sm[input_stream.node].kind + { + if *inner == orig_input { + combined_exprs.extend(exprs.iter().cloned()); + ctx.phys_sm.remove(input_stream.node); + return false; + } + } + true + }) + .collect(); + if !combined_exprs.is_empty() { + let output_schema = schema_for_select(orig_input, &combined_exprs, ctx)?; + let kind = PhysNodeKind::Reduce { + input: orig_input, + exprs: combined_exprs, + }; + let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); + input_streams.insert(PhysStream::first(reduce_node_key)); + } + + Ok(input_streams) +} + // In the recursive lowering we don't bother with named expressions at all, so // we work directly with Nodes. #[recursive::recursive] fn lower_exprs_with_ctx( - input: PhysNodeKey, + input: PhysStream, exprs: &[Node], ctx: &mut LowerExprContext, -) -> PolarsResult<(PhysNodeKey, Vec)> { +) -> PolarsResult<(PhysStream, Vec)> { // We have to catch this case separately, in case all the input independent expressions are elementwise. // TODO: we shouldn't always do this when recursing, e.g. pl.col.a.sum() + 1 will still hit this in the recursion. 
if exprs.iter().all(|e| is_input_independent(*e, ctx)) { @@ -369,14 +369,14 @@ fn lower_exprs_with_ctx( .iter() .map(|e| ctx.expr_arena.add(AExpr::Column(e.output_name().clone()))) .collect(); - return Ok((node, out_exprs)); + return Ok((PhysStream::first(node), out_exprs)); } // Fallback expressions that can directly be applied to the original input. let mut fallback_subset = Vec::new(); - // Nodes containing the columns used for executing transformed expressions. - let mut input_nodes = PlHashSet::new(); + // Streams containing the columns used for executing transformed expressions. + let mut input_streams = PlHashSet::new(); // The final transformed expressions that will be selected from the zipped // together transformed nodes. @@ -385,7 +385,7 @@ fn lower_exprs_with_ctx( for expr in exprs.iter().copied() { if is_elementwise_rec_cached(expr, ctx.expr_arena, ctx.cache) { if !is_input_independent(expr, ctx) { - input_nodes.insert(input); + input_streams.insert(input); } transformed_exprs.push(expr); continue; @@ -407,7 +407,7 @@ fn lower_exprs_with_ctx( extend_original: false, }; let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind)); - input_nodes.insert(node_key); + input_streams.insert(PhysStream::first(node_key)); transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(exploded_name))); }, AExpr::Alias(_, _) => unreachable!("alias found in physical plan"), @@ -415,7 +415,8 @@ fn lower_exprs_with_ctx( AExpr::Literal(_) => { let out_name = unique_column_name(); let inner_expr = ExprIR::new(expr, OutputName::Alias(out_name.clone())); - input_nodes.insert(build_input_independent_node_with_ctx(&[inner_expr], ctx)?); + let node_key = build_input_independent_node_with_ctx(&[inner_expr], ctx)?; + input_streams.insert(PhysStream::first(node_key)); transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); }, AExpr::BinaryExpr { left, op, right } => { @@ -425,7 +426,7 @@ fn lower_exprs_with_ctx( op, right: trans_exprs[1], }; - 
input_nodes.insert(trans_input); + input_streams.insert(trans_input); transformed_exprs.push(ctx.expr_arena.add(bin_expr)); }, AExpr::Ternary { @@ -440,7 +441,7 @@ fn lower_exprs_with_ctx( truthy: trans_exprs[1], falsy: trans_exprs[2], }; - input_nodes.insert(trans_input); + input_streams.insert(trans_input); transformed_exprs.push(ctx.expr_arena.add(tern_expr)); }, AExpr::Cast { @@ -449,7 +450,7 @@ fn lower_exprs_with_ctx( options, } => { let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[inner], ctx)?; - input_nodes.insert(trans_input); + input_streams.insert(trans_input); transformed_exprs.push(ctx.expr_arena.add(AExpr::Cast { expr: trans_exprs[0], dtype, @@ -464,17 +465,18 @@ fn lower_exprs_with_ctx( // expr is available as a column by selecting first. let sorted_name = unique_column_name(); let inner_expr_ir = ExprIR::new(inner, OutputName::Alias(sorted_name.clone())); - let select_node = build_select_node_with_ctx(input, &[inner_expr_ir.clone()], ctx)?; + let select_stream = + build_select_stream_with_ctx(input, &[inner_expr_ir.clone()], ctx)?; let col_expr = ctx.expr_arena.add(AExpr::Column(sorted_name.clone())); let kind = PhysNodeKind::Sort { - input: select_node, + input: select_stream, by_column: vec![ExprIR::new(col_expr, OutputName::Alias(sorted_name))], slice: None, sort_options: (&options).into(), }; - let output_schema = ctx.phys_sm[select_node].output_schema.clone(); + let output_schema = ctx.phys_sm[select_stream.node].output_schema.clone(); let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); - input_nodes.insert(node_key); + input_streams.insert(PhysStream::first(node_key)); transformed_exprs.push(col_expr); }, AExpr::SortBy { @@ -490,11 +492,11 @@ fn lower_exprs_with_ctx( .chain(by_names.iter().zip(by.iter().copied())) .map(|(name, inner)| ExprIR::new(inner, OutputName::Alias(name.clone()))) .collect_vec(); - let select_node = build_select_node_with_ctx(input, &all_inner_expr_irs, ctx)?; + let select_stream = 
build_select_stream_with_ctx(input, &all_inner_expr_irs, ctx)?; // Sort the inputs. let kind = PhysNodeKind::Sort { - input: select_node, + input: select_stream, by_column: by_names .into_iter() .map(|name| { @@ -507,16 +509,17 @@ fn lower_exprs_with_ctx( slice: None, sort_options, }; - let output_schema = ctx.phys_sm[select_node].output_schema.clone(); + let output_schema = ctx.phys_sm[select_stream.node].output_schema.clone(); let sort_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); + let sort_stream = PhysStream::first(sort_node_key); // Drop the by columns. let sorted_col_expr = ctx.expr_arena.add(AExpr::Column(sorted_name.clone())); let sorted_col_ir = ExprIR::new(sorted_col_expr, OutputName::Alias(sorted_name.clone())); - let post_sort_select_node = - build_select_node_with_ctx(sort_node_key, &[sorted_col_ir], ctx)?; - input_nodes.insert(post_sort_select_node); + let post_sort_select_stream = + build_select_stream_with_ctx(sort_stream, &[sorted_col_ir], ctx)?; + input_streams.insert(post_sort_select_stream); transformed_exprs.push(sorted_col_expr); }, AExpr::Gather { .. } => todo!(), @@ -526,8 +529,8 @@ fn lower_exprs_with_ctx( let by_name = unique_column_name(); let inner_expr_ir = ExprIR::new(inner, OutputName::Alias(out_name.clone())); let by_expr_ir = ExprIR::new(by, OutputName::Alias(by_name.clone())); - let select_node = - build_select_node_with_ctx(input, &[inner_expr_ir, by_expr_ir], ctx)?; + let select_stream = + build_select_stream_with_ctx(input, &[inner_expr_ir, by_expr_ir], ctx)?; // Add a filter node. 
let predicate = ExprIR::new( @@ -535,12 +538,12 @@ fn lower_exprs_with_ctx( OutputName::Alias(by_name), ); let kind = PhysNodeKind::Filter { - input: select_node, + input: select_stream, predicate, }; - let output_schema = ctx.phys_sm[select_node].output_schema.clone(); + let output_schema = ctx.phys_sm[select_stream.node].output_schema.clone(); let filter_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); - input_nodes.insert(filter_node_key); + input_streams.insert(PhysStream::first(filter_node_key)); transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); }, AExpr::Agg(mut agg) => match agg { @@ -569,7 +572,7 @@ fn lower_exprs_with_ctx( exprs: vec![expr_ir], }; let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); - input_nodes.insert(reduce_node_key); + input_streams.insert(PhysStream::first(reduce_node_key)); transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); }, IRAggExpr::Median(_) @@ -594,7 +597,7 @@ fn lower_exprs_with_ctx( exprs: vec![expr_ir], }; let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); - input_nodes.insert(reduce_node_key); + input_streams.insert(PhysStream::first(reduce_node_key)); transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); }, AExpr::AnonymousFunction { .. } @@ -609,22 +612,23 @@ fn lower_exprs_with_ctx( } if !fallback_subset.is_empty() { - input_nodes.insert(build_fallback_node_with_ctx(input, &fallback_subset, ctx)?); + let fallback_node = build_fallback_node_with_ctx(input, &fallback_subset, ctx)?; + input_streams.insert(PhysStream::first(fallback_node)); } // Simplify the input nodes (also ensures the original input only occurs // once in the zip). - input_nodes = simplify_input_nodes(input, input_nodes, ctx)?; + input_streams = simplify_input_streams(input, input_streams, ctx)?; - if input_nodes.len() == 1 { + if input_streams.len() == 1 { // No need for any multiplexing/zipping, can directly execute. 
- return Ok((input_nodes.into_iter().next().unwrap(), transformed_exprs)); + return Ok((input_streams.into_iter().next().unwrap(), transformed_exprs)); } - let zip_inputs = input_nodes.into_iter().collect_vec(); + let zip_inputs = input_streams.into_iter().collect_vec(); let output_schema = zip_inputs .iter() - .flat_map(|node| ctx.phys_sm[*node].output_schema.iter_fields()) + .flat_map(|stream| ctx.phys_sm[stream.node].output_schema.iter_fields()) .collect(); let zip_kind = PhysNodeKind::Zip { inputs: zip_inputs, @@ -634,7 +638,7 @@ fn lower_exprs_with_ctx( .phys_sm .insert(PhysNode::new(Arc::new(output_schema), zip_kind)); - Ok((zip_node, transformed_exprs)) + Ok((PhysStream::first(zip_node), transformed_exprs)) } /// Computes the schema that selecting the given expressions on the input schema @@ -662,21 +666,23 @@ pub fn compute_output_schema( /// Computes the schema that selecting the given expressions on the input node /// would result in. fn schema_for_select( - input: PhysNodeKey, + input: PhysStream, exprs: &[ExprIR], ctx: &mut LowerExprContext, ) -> PolarsResult> { - let input_schema = &ctx.phys_sm[input].output_schema; + let input_schema = &ctx.phys_sm[input.node].output_schema; compute_output_schema(input_schema, exprs, ctx.expr_arena) } -fn build_select_node_with_ctx( - input: PhysNodeKey, +fn build_select_stream_with_ctx( + input: PhysStream, exprs: &[ExprIR], ctx: &mut LowerExprContext, -) -> PolarsResult { +) -> PolarsResult { if exprs.iter().all(|e| is_input_independent(e.node(), ctx)) { - return build_input_independent_node_with_ctx(exprs, ctx); + return Ok(PhysStream::first(build_input_independent_node_with_ctx( + exprs, ctx, + )?)); } // Are we only selecting simple columns, with the same name? 
@@ -689,7 +695,7 @@ fn build_select_node_with_ctx( .collect(); if let Some(columns) = all_simple_columns { - let input_schema = ctx.phys_sm[input].output_schema.clone(); + let input_schema = ctx.phys_sm[input.node].output_schema.clone(); if !cfg!(debug_assertions) && input_schema.len() == columns.len() && input_schema.iter_names().zip(&columns).all(|(l, r)| l == r) @@ -700,7 +706,8 @@ fn build_select_node_with_ctx( let output_schema = Arc::new(input_schema.try_project(&columns)?); let node_kind = PhysNodeKind::SimpleProjection { input, columns }; - return Ok(ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind))); + let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind)); + return Ok(PhysStream::first(node_key)); } let node_exprs = exprs.iter().map(|e| e.node()).collect_vec(); @@ -716,7 +723,8 @@ fn build_select_node_with_ctx( selectors: trans_expr_irs, extend_original: false, }; - Ok(ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind))) + let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind)); + Ok(PhysStream::first(node_key)) } /// Lowers an input node plus a set of expressions on that input node to an @@ -725,12 +733,12 @@ fn build_select_node_with_ctx( /// /// Ensures that if the input node is transformed it has unique column names. pub fn lower_exprs( - input: PhysNodeKey, + input: PhysStream, exprs: &[ExprIR], expr_arena: &mut Arena, phys_sm: &mut SlotMap, expr_cache: &mut ExprCache, -) -> PolarsResult<(PhysNodeKey, Vec)> { +) -> PolarsResult<(PhysStream, Vec)> { let mut ctx = LowerExprContext { expr_arena, phys_sm, @@ -747,18 +755,19 @@ pub fn lower_exprs( Ok((transformed_input, trans_expr_irs)) } -/// Builds a selection node given an input node and the expressions to select for. -pub fn build_select_node( - input: PhysNodeKey, +/// Builds a new selection node given an input stream and the expressions to +/// select for, if needed. 
+pub fn build_select_stream( + input: PhysStream, exprs: &[ExprIR], expr_arena: &mut Arena, phys_sm: &mut SlotMap, expr_cache: &mut ExprCache, -) -> PolarsResult { +) -> PolarsResult { let mut ctx = LowerExprContext { expr_arena, phys_sm, cache: expr_cache, }; - build_select_node_with_ctx(input, exprs, &mut ctx) + build_select_stream_with_ctx(input, exprs, &mut ctx) } diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index 7b0140751159..649ed49453b1 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -12,41 +12,43 @@ use polars_utils::arena::{Arena, Node}; use polars_utils::itertools::Itertools; use slotmap::SlotMap; -use super::{PhysNode, PhysNodeKey, PhysNodeKind}; +use super::{PhysNode, PhysNodeKey, PhysNodeKind, PhysStream}; use crate::physical_plan::lower_expr::{ - build_select_node, is_elementwise_rec_cached, lower_exprs, ExprCache, + build_select_stream, is_elementwise_rec_cached, lower_exprs, ExprCache, }; -fn build_slice_node( - input: PhysNodeKey, +/// Creates a new PhysStream which outputs a slice of the input stream. +fn build_slice_stream( + input: PhysStream, offset: i64, length: usize, phys_sm: &mut SlotMap, -) -> PhysNodeKey { +) -> PhysStream { if offset >= 0 { let offset = offset as usize; - phys_sm.insert(PhysNode::new( - phys_sm[input].output_schema.clone(), + PhysStream::first(phys_sm.insert(PhysNode::new( + phys_sm[input.node].output_schema.clone(), PhysNodeKind::StreamingSlice { input, offset, length, }, - )) + ))) } else { todo!() } } -fn build_filter_node( - input: PhysNodeKey, +/// Creates a new PhysStream which filters the input stream.
+fn build_filter_stream( + input: PhysStream, predicate: ExprIR, expr_arena: &mut Arena, phys_sm: &mut SlotMap, expr_cache: &mut ExprCache, -) -> PolarsResult { +) -> PolarsResult { let predicate = predicate.clone(); - let cols_and_predicate = phys_sm[input] + let cols_and_predicate = phys_sm[input.node] .output_schema .iter_names() .cloned() @@ -61,7 +63,7 @@ fn build_filter_node( let (trans_input, mut trans_cols_and_predicate) = lower_exprs(input, &cols_and_predicate, expr_arena, phys_sm, expr_cache)?; - let filter_schema = phys_sm[trans_input].output_schema.clone(); + let filter_schema = phys_sm[trans_input.node].output_schema.clone(); let filter = PhysNodeKind::Filter { input: trans_input, predicate: trans_cols_and_predicate.last().unwrap().clone(), @@ -69,8 +71,8 @@ fn build_filter_node( let post_filter = phys_sm.insert(PhysNode::new(filter_schema, filter)); trans_cols_and_predicate.pop(); // Remove predicate. - build_select_node( - post_filter, + build_select_stream( + PhysStream::first(post_filter), &trans_cols_and_predicate, expr_arena, phys_sm, @@ -86,8 +88,8 @@ pub fn lower_ir( phys_sm: &mut SlotMap, schema_cache: &mut PlHashMap>, expr_cache: &mut ExprCache, - cache_nodes: &mut PlHashMap, -) -> PolarsResult { + cache_nodes: &mut PlHashMap, +) -> PolarsResult { // Helper macro to simplify recursive calls. macro_rules! lower_ir { ($input:expr) => { @@ -118,7 +120,7 @@ pub fn lower_ir( IR::Select { input, expr, .. } => { let selectors = expr.clone(); let phys_input = lower_ir!(*input)?; - return build_select_node(phys_input, &selectors, expr_arena, phys_sm, expr_cache); + return build_select_stream(phys_input, &selectors, expr_arena, phys_sm, expr_cache); }, IR::HStack { input, exprs, .. } @@ -143,7 +145,7 @@ pub fn lower_ir( // FIXME: constant literal columns should be broadcasted with hstack. 
let exprs = exprs.clone(); let phys_input = lower_ir!(*input)?; - let input_schema = &phys_sm[phys_input].output_schema; + let input_schema = &phys_sm[phys_input.node].output_schema; let mut selectors = PlIndexMap::with_capacity(input_schema.len() + exprs.len()); for name in input_schema.iter_names() { let col_name = name.clone(); @@ -157,20 +159,20 @@ pub fn lower_ir( selectors.insert(expr.output_name().clone(), expr); } let selectors = selectors.into_values().collect_vec(); - return build_select_node(phys_input, &selectors, expr_arena, phys_sm, expr_cache); + return build_select_stream(phys_input, &selectors, expr_arena, phys_sm, expr_cache); }, IR::Slice { input, offset, len } => { let offset = *offset; let len = *len as usize; let phys_input = lower_ir!(*input)?; - return Ok(build_slice_node(phys_input, offset, len, phys_sm)); + return Ok(build_slice_stream(phys_input, offset, len, phys_sm)); }, IR::Filter { input, predicate } => { let predicate = predicate.clone(); let phys_input = lower_ir!(*input)?; - return build_filter_node(phys_input, predicate, expr_arena, phys_sm, expr_cache); + return build_filter_stream(phys_input, predicate, expr_arena, phys_sm, expr_cache); }, IR::DataFrameScan { @@ -192,7 +194,7 @@ pub fn lower_ir( { let phys_input = phys_sm.insert(PhysNode::new(schema, node_kind)); node_kind = PhysNodeKind::SimpleProjection { - input: phys_input, + input: PhysStream::first(phys_input), columns: projection_schema.iter_names_cloned().collect::>(), }; } @@ -289,14 +291,15 @@ pub fn lower_ir( .map(|input| lower_ir!(input)) .collect::>()?; - let mut node = phys_sm.insert(PhysNode { + let node = phys_sm.insert(PhysNode { output_schema, kind: PhysNodeKind::OrderedUnion { inputs }, }); + let mut stream = PhysStream::first(node); if let Some((offset, length)) = options.slice { - node = build_slice_node(node, offset, length, phys_sm); + stream = build_slice_stream(stream, offset, length, phys_sm); } - return Ok(node); + return Ok(stream); }, IR::HConcat { 
@@ -396,10 +399,10 @@ pub fn lower_ir( let v = schema.shift_remove_index(0).unwrap().0; assert_eq!(v, ri.name); - let input = phys_sm.insert(PhysNode::new(Arc::new(schema), node_kind)); + let input_node = phys_sm.insert(PhysNode::new(Arc::new(schema), node_kind)); PhysNodeKind::WithRowIndex { - input, + input: PhysStream::first(input_node), name: ri.name, offset: Some(ri.offset), } @@ -407,20 +410,22 @@ pub fn lower_ir( node_kind }; - let mut node = phys_sm.insert(PhysNode { + let node = phys_sm.insert(PhysNode { output_schema, kind: node_kind, }); + let mut stream = PhysStream::first(node); if let Some((offset, length)) = slice { - node = build_slice_node(node, offset, length, phys_sm); + stream = build_slice_stream(stream, offset, length, phys_sm); } if let Some(predicate) = predicate { - node = build_filter_node(node, predicate, expr_arena, phys_sm, expr_cache)?; + stream = + build_filter_stream(stream, predicate, expr_arena, phys_sm, expr_cache)?; } - return Ok(node); + return Ok(stream); } }, @@ -504,7 +509,7 @@ pub fn lower_ir( }) .collect(); - let mut node = phys_sm.insert(PhysNode::new( + let node = phys_sm.insert(PhysNode::new( output_schema, PhysNodeKind::GroupBy { input: trans_input, @@ -515,10 +520,11 @@ pub fn lower_ir( // TODO: actually limit number of groups instead of computing full // result and then slicing. + let mut stream = PhysStream::first(node); if let Some((offset, len)) = options.slice { - node = build_slice_node(node, offset, len, phys_sm); + stream = build_slice_stream(stream, offset, len, phys_sm); } - return Ok(node); + return Ok(stream); }, IR::Join { input_left, @@ -542,12 +548,12 @@ pub fn lower_ir( // nodes since the lowering code does not see we access any non-literal expressions. // So we add dummy expressions before lowering and remove them afterwards. 
let mut aug_left_on = left_on.clone(); - for name in phys_sm[phys_left].output_schema.iter_names() { + for name in phys_sm[phys_left.node].output_schema.iter_names() { let col_expr = expr_arena.add(AExpr::Column(name.clone())); aug_left_on.push(ExprIR::new(col_expr, OutputName::ColumnLhs(name.clone()))); } let mut aug_right_on = right_on.clone(); - for name in phys_sm[phys_right].output_schema.iter_names() { + for name in phys_sm[phys_right.node].output_schema.iter_names() { let col_expr = expr_arena.add(AExpr::Column(name.clone())); aug_right_on.push(ExprIR::new(col_expr, OutputName::ColumnLhs(name.clone()))); } @@ -558,7 +564,7 @@ pub fn lower_ir( trans_left_on.drain(left_on.len()..); trans_right_on.drain(right_on.len()..); - let mut node = phys_sm.insert(PhysNode::new( + let node = phys_sm.insert(PhysNode::new( output_schema, PhysNodeKind::EquiJoin { input_left: trans_input_left, @@ -568,10 +574,11 @@ pub fn lower_ir( args: args.clone(), }, )); + let mut stream = PhysStream::first(node); if let Some((offset, len)) = args.slice { - node = build_slice_node(node, offset, len, phys_sm); + stream = build_slice_stream(stream, offset, len, phys_sm); } - return Ok(node); + return Ok(stream); } else { PhysNodeKind::InMemoryJoin { input_left: phys_left, @@ -588,5 +595,6 @@ pub fn lower_ir( IR::Invalid => unreachable!(), }; - Ok(phys_sm.insert(PhysNode::new(output_schema, node_kind))) + let node_key = phys_sm.insert(PhysNode::new(output_schema, node_kind)); + Ok(PhysStream::first(node_key)) } diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index 1d33addb4c4a..f311de368d07 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -20,13 +20,13 @@ pub use fmt::visualize_plan; use polars_plan::prelude::{FileScanOptions, FileType}; use polars_utils::arena::{Arena, Node}; use polars_utils::pl_str::PlSmallStr; -use slotmap::{Key, SecondaryMap, SlotMap}; +use 
slotmap::{SecondaryMap, SlotMap}; pub use to_graph::physical_plan_to_graph; use crate::physical_plan::lower_expr::ExprCache; slotmap::new_key_type! { - /// Key used for PNodes. + /// Key used for physical nodes. pub struct PhysNodeKey; } @@ -49,6 +49,27 @@ impl PhysNode { } } +/// A handle representing a physical stream of data with a fixed schema in the +/// physical plan. It consists of a reference to a physical node as well as the +/// output port on that node to connect to receive the stream. +#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)] +pub struct PhysStream { + pub node: PhysNodeKey, + pub port: usize, +} + +impl PhysStream { + #[expect(unused)] + pub fn new(node: PhysNodeKey, port: usize) -> Self { + Self { node, port } + } + + // Convenience method to refer to the first output port of a physical node. + pub fn first(node: PhysNodeKey) -> Self { + Self { node, port: 0 } + } +} + #[derive(Clone, Debug)] pub enum PhysNodeKind { InMemorySource { @@ -56,13 +77,13 @@ pub enum PhysNodeKind { }, Select { - input: PhysNodeKey, + input: PhysStream, selectors: Vec, extend_original: bool, }, WithRowIndex { - input: PhysNodeKey, + input: PhysStream, name: PlSmallStr, offset: Option, }, @@ -72,62 +93,62 @@ pub enum PhysNodeKind { }, Reduce { - input: PhysNodeKey, + input: PhysStream, exprs: Vec, }, StreamingSlice { - input: PhysNodeKey, + input: PhysStream, offset: usize, length: usize, }, Filter { - input: PhysNodeKey, + input: PhysStream, predicate: ExprIR, }, SimpleProjection { - input: PhysNodeKey, + input: PhysStream, columns: Vec, }, InMemorySink { - input: PhysNodeKey, + input: PhysStream, }, FileSink { path: Arc, file_type: FileType, - input: PhysNodeKey, + input: PhysStream, }, /// Generic fallback for (as-of-yet) unsupported streaming mappings. /// Fully sinks all data to an in-memory data frame and uses the in-memory /// engine to perform the map. 
InMemoryMap { - input: PhysNodeKey, + input: PhysStream, map: Arc, }, Map { - input: PhysNodeKey, + input: PhysStream, map: Arc, }, Sort { - input: PhysNodeKey, + input: PhysStream, by_column: Vec, slice: Option<(i64, usize)>, sort_options: SortMultipleOptions, }, OrderedUnion { - inputs: Vec, + inputs: Vec, }, Zip { - inputs: Vec, + inputs: Vec, /// If true shorter inputs are extended with nulls to the longest input, /// if false all inputs must be the same length, or have length 1 in /// which case they are broadcast. @@ -136,7 +157,7 @@ pub enum PhysNodeKind { #[allow(unused)] Multiplexer { - input: PhysNodeKey, + input: PhysStream, }, FileScan { @@ -150,14 +171,14 @@ pub enum PhysNodeKind { }, GroupBy { - input: PhysNodeKey, + input: PhysStream, key: Vec, aggs: Vec, }, EquiJoin { - input_left: PhysNodeKey, - input_right: PhysNodeKey, + input_left: PhysStream, + input_right: PhysStream, left_on: Vec, right_on: Vec, args: JoinArgs, @@ -167,8 +188,8 @@ pub enum PhysNodeKind { /// Fully sinks all data to in-memory data frames and uses the in-memory /// engine to perform the join. InMemoryJoin { - input_left: PhysNodeKey, - input_right: PhysNodeKey, + input_left: PhysStream, + input_right: PhysStream, left_on: Vec, right_on: Vec, args: JoinArgs, @@ -176,34 +197,24 @@ pub enum PhysNodeKind { }, } -#[recursive::recursive] -fn insert_multiplexers( - node: PhysNodeKey, +fn visit_node_inputs_mut( + roots: Vec, phys_sm: &mut SlotMap, - referenced: &mut SecondaryMap, + mut visit: impl FnMut(&mut PhysStream), ) { - let seen_before = referenced.insert(node, ()).is_some(); - if seen_before && !matches!(phys_sm[node].kind, PhysNodeKind::Multiplexer { .. }) { - // This node is referenced at least twice. We first set the input key to - // null and then update it to avoid a double-mutable-borrow issue. 
- let input_schema = phys_sm[node].output_schema.clone(); - let orig_input_node = core::mem::replace( - &mut phys_sm[node], - PhysNode::new( - input_schema, - PhysNodeKind::Multiplexer { - input: PhysNodeKey::null(), - }, - ), - ); - let orig_input_key = phys_sm.insert(orig_input_node); - phys_sm[node].kind = PhysNodeKind::Multiplexer { - input: orig_input_key, + let mut to_visit = roots; + let mut seen: SecondaryMap = + to_visit.iter().copied().map(|n| (n, ())).collect(); + macro_rules! rec { + ($n:expr) => { + let n = $n; + if seen.insert(n, ()).is_none() { + to_visit.push(n) + } }; } - - if !seen_before { - match &phys_sm[node].kind { + while let Some(node) = to_visit.pop() { + match &mut phys_sm[node].kind { PhysNodeKind::InMemorySource { .. } | PhysNodeKind::FileScan { .. } | PhysNodeKind::InputIndependentSelect { .. } => {}, @@ -220,7 +231,8 @@ fn insert_multiplexers( | PhysNodeKind::Sort { input, .. } | PhysNodeKind::Multiplexer { input } | PhysNodeKind::GroupBy { input, .. } => { - insert_multiplexers(*input, phys_sm, referenced); + rec!(input.node); + visit(input); }, PhysNodeKind::InMemoryJoin { @@ -233,20 +245,49 @@ fn insert_multiplexers( input_right, .. } => { - let input_right = *input_right; - insert_multiplexers(*input_left, phys_sm, referenced); - insert_multiplexers(input_right, phys_sm, referenced); + rec!(input_left.node); + rec!(input_right.node); + visit(input_left); + visit(input_right); }, PhysNodeKind::OrderedUnion { inputs } | PhysNodeKind::Zip { inputs, .. 
} => { - for input in inputs.clone() { - insert_multiplexers(input, phys_sm, referenced); + for input in inputs { + rec!(input.node); + visit(input); } }, } } } +fn insert_multiplexers(roots: Vec, phys_sm: &mut SlotMap) { + let mut refcount = PlHashMap::new(); + visit_node_inputs_mut(roots.clone(), phys_sm, |i| { + *refcount.entry(*i).or_insert(0) += 1; + }); + + let mut multiplexer_map: PlHashMap = refcount + .into_iter() + .filter(|(_stream, refcount)| *refcount > 1) + .map(|(stream, _refcount)| { + let input_schema = phys_sm[stream.node].output_schema.clone(); + let multiplexer_node = phys_sm.insert(PhysNode::new( + input_schema, + PhysNodeKind::Multiplexer { input: stream }, + )); + (stream, PhysStream::first(multiplexer_node)) + }) + .collect(); + + visit_node_inputs_mut(roots, phys_sm, |i| { + if let Some(m) = multiplexer_map.get_mut(i) { + *i = *m; + m.port += 1; + } + }); +} + pub fn build_physical_plan( root: Node, ir_arena: &mut Arena, @@ -265,7 +306,6 @@ pub fn build_physical_plan( &mut expr_cache, &mut cache_nodes, )?; - let mut referenced = SecondaryMap::with_capacity(phys_sm.capacity()); - insert_multiplexers(phys_root, phys_sm, &mut referenced); - Ok(phys_root) + insert_multiplexers(vec![phys_root.node], phys_sm); + Ok(phys_root.node) } diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 1f6cbecb6cbf..c9182a3eb964 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -102,20 +102,20 @@ fn to_graph_rec<'a>( offset, length, } => { - let input_key = to_graph_rec(*input, ctx)?; + let input_key = to_graph_rec(input.node, ctx)?; ctx.graph.add_node( nodes::streaming_slice::StreamingSliceNode::new(*offset, *length), - [input_key], + [(input_key, input.port)], ) }, Filter { predicate, input } => { - let input_schema = &ctx.phys_sm[*input].output_schema; + let input_schema = &ctx.phys_sm[input.node].output_schema; let 
phys_predicate_expr = create_stream_expr(predicate, ctx, input_schema)?; - let input_key = to_graph_rec(*input, ctx)?; + let input_key = to_graph_rec(input.node, ctx)?; ctx.graph.add_node( nodes::filter::FilterNode::new(phys_predicate_expr), - [input_key], + [(input_key, input.port)], ) }, @@ -124,19 +124,19 @@ fn to_graph_rec<'a>( input, extend_original, } => { - let input_schema = &ctx.phys_sm[*input].output_schema; + let input_schema = &ctx.phys_sm[input.node].output_schema; let phys_selectors = selectors .iter() .map(|selector| create_stream_expr(selector, ctx, input_schema)) .collect::>()?; - let input_key = to_graph_rec(*input, ctx)?; + let input_key = to_graph_rec(input.node, ctx)?; ctx.graph.add_node( nodes::select::SelectNode::new( phys_selectors, node.output_schema.clone(), *extend_original, ), - [input_key], + [(input_key, input.port)], ) }, @@ -145,10 +145,10 @@ fn to_graph_rec<'a>( name, offset, } => { - let input_key = to_graph_rec(*input, ctx)?; + let input_key = to_graph_rec(input.node, ctx)?; ctx.graph.add_node( nodes::with_row_index::WithRowIndexNode::new(name.clone(), *offset), - [input_key], + [(input_key, input.port)], ) }, @@ -165,8 +165,8 @@ fn to_graph_rec<'a>( }, Reduce { input, exprs } => { - let input_key = to_graph_rec(*input, ctx)?; - let input_schema = &ctx.phys_sm[*input].output_schema; + let input_key = to_graph_rec(input.node, ctx)?; + let input_schema = &ctx.phys_sm[input.node].output_schema; let mut reductions = Vec::with_capacity(exprs.len()); let mut inputs = Vec::with_capacity(reductions.len()); @@ -186,24 +186,24 @@ fn to_graph_rec<'a>( ctx.graph.add_node( nodes::reduce::ReduceNode::new(inputs, reductions, node.output_schema.clone()), - [input_key], + [(input_key, input.port)], ) }, SimpleProjection { input, columns } => { - let input_schema = ctx.phys_sm[*input].output_schema.clone(); - let input_key = to_graph_rec(*input, ctx)?; + let input_schema = ctx.phys_sm[input.node].output_schema.clone(); + let input_key = 
to_graph_rec(input.node, ctx)?; ctx.graph.add_node( nodes::simple_projection::SimpleProjectionNode::new(columns.clone(), input_schema), - [input_key], + [(input_key, input.port)], ) }, InMemorySink { input } => { - let input_schema = ctx.phys_sm[*input].output_schema.clone(); - let input_key = to_graph_rec(*input, ctx)?; + let input_schema = ctx.phys_sm[input.node].output_schema.clone(); + let input_key = to_graph_rec(input.node, ctx)?; ctx.graph.add_node( nodes::in_memory_sink::InMemorySinkNode::new(input_schema), - [input_key], + [(input_key, input.port)], ) }, @@ -212,32 +212,34 @@ fn to_graph_rec<'a>( file_type, input, } => { - let input_schema = ctx.phys_sm[*input].output_schema.clone(); - let input_key = to_graph_rec(*input, ctx)?; + let input_schema = ctx.phys_sm[input.node].output_schema.clone(); + let input_key = to_graph_rec(input.node, ctx)?; match file_type { #[cfg(feature = "ipc")] FileType::Ipc(ipc_writer_options) => ctx.graph.add_node( nodes::io_sinks::ipc::IpcSinkNode::new(input_schema, path, ipc_writer_options)?, - [input_key], + [(input_key, input.port)], ), _ => todo!(), } }, InMemoryMap { input, map } => { - let input_schema = ctx.phys_sm[*input].output_schema.clone(); - let input_key = to_graph_rec(*input, ctx)?; + let input_schema = ctx.phys_sm[input.node].output_schema.clone(); + let input_key = to_graph_rec(input.node, ctx)?; ctx.graph.add_node( nodes::in_memory_map::InMemoryMapNode::new(input_schema, map.clone()), - [input_key], + [(input_key, input.port)], ) }, Map { input, map } => { - let input_key = to_graph_rec(*input, ctx)?; - ctx.graph - .add_node(nodes::map::MapNode::new(map.clone()), [input_key]) + let input_key = to_graph_rec(input.node, ctx)?; + ctx.graph.add_node( + nodes::map::MapNode::new(map.clone()), + [(input_key, input.port)], + ) }, Sort { @@ -246,7 +248,7 @@ fn to_graph_rec<'a>( slice, sort_options, } => { - let input_schema = ctx.phys_sm[*input].output_schema.clone(); + let input_schema = 
ctx.phys_sm[input.node].output_schema.clone(); let lmdf = Arc::new(LateMaterializedDataFrame::default()); let mut lp_arena = Arena::default(); let df_node = lp_arena.add(lmdf.clone().as_ir_node(input_schema.clone())); @@ -262,7 +264,7 @@ fn to_graph_rec<'a>( ctx.expr_arena, )?); - let input_key = to_graph_rec(*input, ctx)?; + let input_key = to_graph_rec(input.node, ctx)?; ctx.graph.add_node( nodes::in_memory_map::InMemoryMapNode::new( input_schema, @@ -272,15 +274,15 @@ fn to_graph_rec<'a>( executor.lock().execute(&mut state) }), ), - [input_key], + [(input_key, input.port)], ) }, OrderedUnion { inputs } => { let input_keys = inputs .iter() - .map(|i| to_graph_rec(*i, ctx)) - .collect::, _>>()?; + .map(|i| PolarsResult::Ok((to_graph_rec(i.node, ctx)?, i.port))) + .try_collect_vec()?; ctx.graph .add_node(nodes::ordered_union::OrderedUnionNode::new(), input_keys) }, @@ -291,11 +293,11 @@ fn to_graph_rec<'a>( } => { let input_schemas = inputs .iter() - .map(|i| ctx.phys_sm[*i].output_schema.clone()) + .map(|i| ctx.phys_sm[i.node].output_schema.clone()) .collect_vec(); let input_keys = inputs .iter() - .map(|i| to_graph_rec(*i, ctx)) + .map(|i| PolarsResult::Ok((to_graph_rec(i.node, ctx)?, i.port))) .try_collect_vec()?; ctx.graph.add_node( nodes::zip::ZipNode::new(*null_extend, input_schemas), @@ -304,9 +306,11 @@ fn to_graph_rec<'a>( }, Multiplexer { input } => { - let input_key = to_graph_rec(*input, ctx)?; - ctx.graph - .add_node(nodes::multiplexer::MultiplexerNode::new(), [input_key]) + let input_key = to_graph_rec(input.node, ctx)?; + ctx.graph.add_node( + nodes::multiplexer::MultiplexerNode::new(), + [(input_key, input.port)], + ) }, v @ FileScan { .. 
} => { @@ -413,9 +417,9 @@ fn to_graph_rec<'a>( }, GroupBy { input, key, aggs } => { - let input_key = to_graph_rec(*input, ctx)?; + let input_key = to_graph_rec(input.node, ctx)?; - let input_schema = &ctx.phys_sm[*input].output_schema; + let input_schema = &ctx.phys_sm[input.node].output_schema; let key_schema = compute_output_schema(input_schema, key, ctx.expr_arena)?; let grouper = new_hash_grouper(key_schema); @@ -447,7 +451,7 @@ fn to_graph_rec<'a>( node.output_schema.clone(), PlRandomState::new(), ), - [input_key], + [(input_key, input.port)], ) }, @@ -459,10 +463,10 @@ fn to_graph_rec<'a>( args, options, } => { - let left_input_key = to_graph_rec(*input_left, ctx)?; - let right_input_key = to_graph_rec(*input_right, ctx)?; - let left_input_schema = ctx.phys_sm[*input_left].output_schema.clone(); - let right_input_schema = ctx.phys_sm[*input_right].output_schema.clone(); + let left_input_key = to_graph_rec(input_left.node, ctx)?; + let right_input_key = to_graph_rec(input_right.node, ctx)?; + let left_input_schema = ctx.phys_sm[input_left.node].output_schema.clone(); + let right_input_schema = ctx.phys_sm[input_right.node].output_schema.clone(); let mut lp_arena = Arena::default(); let left_lmdf = Arc::new(LateMaterializedDataFrame::default()); @@ -504,7 +508,10 @@ fn to_graph_rec<'a>( executor.lock().execute(&mut state) }), ), - [left_input_key, right_input_key], + [ + (left_input_key, input_left.port), + (right_input_key, input_right.port), + ], ) }, @@ -516,10 +523,10 @@ fn to_graph_rec<'a>( args, } => { let args = args.clone(); - let left_input_key = to_graph_rec(*input_left, ctx)?; - let right_input_key = to_graph_rec(*input_right, ctx)?; - let left_input_schema = ctx.phys_sm[*input_left].output_schema.clone(); - let right_input_schema = ctx.phys_sm[*input_right].output_schema.clone(); + let left_input_key = to_graph_rec(input_left.node, ctx)?; + let right_input_key = to_graph_rec(input_right.node, ctx)?; + let left_input_schema = 
ctx.phys_sm[input_left.node].output_schema.clone(); + let right_input_schema = ctx.phys_sm[input_right.node].output_schema.clone(); let left_key_schema = compute_output_schema(&left_input_schema, left_on, ctx.expr_arena)?; @@ -545,7 +552,10 @@ fn to_graph_rec<'a>( right_key_selectors, args, )?, - [left_input_key, right_input_key], + [ + (left_input_key, input_left.port), + (right_input_key, input_right.port), + ], ) }, }; diff --git a/crates/polars-stream/src/skeleton.rs b/crates/polars-stream/src/skeleton.rs index 97e6a9c73272..77ed1673cf3b 100644 --- a/crates/polars-stream/src/skeleton.rs +++ b/crates/polars-stream/src/skeleton.rs @@ -33,8 +33,10 @@ pub fn run_query( } let (mut graph, phys_to_graph) = crate::physical_plan::physical_plan_to_graph(root, &phys_sm, expr_arena)?; + crate::async_executor::clear_task_wait_statistics(); let mut results = crate::execute::execute_graph(&mut graph)?; + if std::env::var("POLARS_TRACK_WAIT_STATS").as_deref() == Ok("1") { let mut stats = crate::async_executor::get_task_wait_statistics(); stats.sort_by_key(|(_l, w)| Reverse(*w));