From 73891ee0397c0c6c46e78256c74129aae4bc663c Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 14 Nov 2024 20:48:02 +0000 Subject: [PATCH 1/7] Add `#[recursive]` --- Cargo.toml | 4 +++- derive/src/lib.rs | 1 + src/ast/mod.rs | 1 + src/ast/visitor.rs | 26 ++++++++++++++++++++++++++ tests/sqlparser_common.rs | 10 ++++++++++ 5 files changed, 41 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 18b246e04..e5f6efbc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ path = "src/lib.rs" [features] default = ["std"] -std = [] +std = ["recursive"] # Enable JSON output in the `cli` example: json_example = ["serde_json", "serde"] visitor = ["sqlparser_derive"] @@ -46,6 +46,8 @@ visitor = ["sqlparser_derive"] [dependencies] bigdecimal = { version = "0.4.1", features = ["serde"], optional = true } log = "0.4" +recursive = { version = "0.1.1", optional = true} + serde = { version = "1.0", features = ["derive"], optional = true } # serde_json is only used in examples/cli, but we have to put it outside # of dev-dependencies because of diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 5ad1607f9..ffa56a533 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -78,6 +78,7 @@ fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_ let expanded = quote! { // The generated impl. impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause { + #[cfg_attr(feature = "std", recursive::recursive)] fn visit( &#modifier self, visitor: &mut V diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 505386fbf..6c77470e3 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1188,6 +1188,7 @@ impl fmt::Display for CastFormat { } impl fmt::Display for Expr { + #[cfg_attr(feature = "std", recursive::recursive)] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Expr::Identifier(s) => write!(f, "{s}"), diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 418e0a299..5dbae504b 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -884,4 +884,30 @@ mod tests { assert_eq!(actual, expected) } } + + + struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes + + impl Visitor for QuickVisitor { + type Break = (); + } + + #[test] + fn overflow() { + let cond = (0..1000) + .map(|n| format!("X = {}", n)) + .collect::>() + .join(" OR "); + let sql = format!("SELECT x where {0}", cond); + + let dialect = GenericDialect {}; + let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap(); + let s = Parser::new(&dialect) + .with_tokens(tokens) + .parse_statement() + .unwrap(); + + let mut visitor = QuickVisitor {} ; + s.visit(&mut visitor); + } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index daf65edf1..14481d477 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11748,3 +11748,13 @@ fn parse_create_table_select() { ); } } + +#[test] +fn overflow() { + let expr = std::iter::repeat("1").take(1000).collect::>().join(" + "); + let sql = format!("SELECT {}", expr); + + let mut statements = Parser::parse_sql(&GenericDialect {}, sql.as_str()).unwrap(); + let statement = statements.pop().unwrap(); + assert_eq!(statement.to_string(), sql); +} From 39f710db84b12628b889a9e1a5dbdcedf543ea1f Mon Sep 17 00:00:00 2001 From: blaginin Date: Fri, 15 Nov 2024 19:01:54 +0000 Subject: [PATCH 2/7] Add larger benchmarks --- sqlparser_bench/benches/sqlparser_bench.rs | 40 ++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/sqlparser_bench/benches/sqlparser_bench.rs b/sqlparser_bench/benches/sqlparser_bench.rs index 27c58b450..52716e97d 100644 --- a/sqlparser_bench/benches/sqlparser_bench.rs +++ b/sqlparser_bench/benches/sqlparser_bench.rs @@ -42,6 +42,46 @@ fn basic_queries(c: &mut Criterion) { group.bench_function("sqlparser::with_select", |b| { b.iter(|| Parser::parse_sql(&dialect, with_query)); }); + + let complex_sql = { + let expressions = (0..1000) + .map(|n| format!("FN_{}(COL_{})", n, n)) + .collect::>() + .join(", "); + let tables = (0..1000) + .map(|n| format!("TABLE_{}", n)) + .collect::>() + .join(" JOIN "); + let where_condition = (0..1000) + .map(|n| format!("COL_{} = {}", n, n)) + .collect::>() + .join(" OR "); + let order_condition = (0..1000) + .map(|n| format!("COL_{} DESC", n)) + .collect::>() + .join(", "); + + format!( + "SELECT {} FROM {} WHERE {} ORDER BY {}", + expressions, tables, where_condition, order_condition + ) + }; + + group.bench_function("parse_large_query", |b| { + b.iter(|| Parser::parse_sql(&dialect, criterion::black_box(complex_sql.as_str()))); + }); + + let complex_query = Parser::parse_sql(&dialect, complex_sql.as_str()) + .unwrap() + .pop() + .unwrap(); + + group.bench_function("format_large_query", |b| { + b.iter(|| { + let formatted_query = complex_query.to_string(); + assert_eq!(formatted_query, complex_sql); + }); + }); } criterion_group!(benches, basic_queries); From 90843ad551511de19d9c186b1f83226778edd084 Mon Sep 17 00:00:00 2001 From: blaginin Date: Fri, 15 Nov 2024 19:10:35 +0000 Subject: [PATCH 3/7] Rename --- sqlparser_bench/benches/sqlparser_bench.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sqlparser_bench/benches/sqlparser_bench.rs b/sqlparser_bench/benches/sqlparser_bench.rs index 52716e97d..afd53e117 100644 --- a/sqlparser_bench/benches/sqlparser_bench.rs +++ b/sqlparser_bench/benches/sqlparser_bench.rs @@ -43,7 +43,7 @@ fn basic_queries(c: &mut Criterion) { b.iter(|| Parser::parse_sql(&dialect, with_query)); }); - let complex_sql = { + let large_statement = { let expressions = (0..1000) .map(|n| format!("FN_{}(COL_{})", n, n)) .collect::>() @@ -67,19 +67,19 @@ fn basic_queries(c: &mut Criterion) { ) }; - group.bench_function("parse_large_query", |b| { - b.iter(|| Parser::parse_sql(&dialect, criterion::black_box(complex_sql.as_str()))); + group.bench_function("parse_large_statement", |b| { + b.iter(|| Parser::parse_sql(&dialect, criterion::black_box(large_statement.as_str()))); }); - let complex_query = Parser::parse_sql(&dialect, complex_sql.as_str()) + let large_statement = Parser::parse_sql(&dialect, large_statement.as_str()) .unwrap() .pop() .unwrap(); - group.bench_function("format_large_query", |b| { + group.bench_function("format_large_statement", |b| { b.iter(|| { - let formatted_query = complex_query.to_string(); - assert_eq!(formatted_query, complex_sql); + let formatted_query = large_statement.to_string(); + assert_eq!(formatted_query, large_statement); }); }); } From 3af997b42001a6462ffebdeef231b152f286cb84 Mon Sep 17 00:00:00 2001 From: blaginin Date: Fri, 15 Nov 2024 19:12:01 +0000 Subject: [PATCH 4/7] Cargo fmt --- src/ast/visitor.rs | 3 +-- tests/sqlparser_common.rs | 5 ++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 5dbae504b..39e701f1e 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -885,7 +885,6 @@ mod tests { } } - struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes impl Visitor for QuickVisitor { @@ -907,7 +906,7 @@ mod tests { .parse_statement() .unwrap(); - let mut visitor = QuickVisitor {} ; + let mut visitor = QuickVisitor {}; s.visit(&mut visitor); } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 14481d477..d756ec447 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11751,7 +11751,10 @@ fn parse_create_table_select() { #[test] fn overflow() { - let expr = std::iter::repeat("1").take(1000).collect::>().join(" + "); + let expr = std::iter::repeat("1") + .take(1000) + .collect::>() + .join(" + "); let sql = format!("SELECT {}", expr); let mut statements = Parser::parse_sql(&GenericDialect {}, sql.as_str()).unwrap(); From 82b26f64bde43e1bca15acd598deea6942507d66 Mon Sep 17 00:00:00 2001 From: blaginin Date: Tue, 26 Nov 2024 19:11:45 +0000 Subject: [PATCH 5/7] Cargo fmt --- tests/sqlparser_common.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index c3cc8a5fd..2c5afd434 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -12441,7 +12441,6 @@ fn test_reserved_keywords_for_identifiers() { dialects.parse_sql_statements(sql).unwrap(); } - #[test] fn overflow() { let expr = std::iter::repeat("1") From c9eef87893479b8fd993f42870cedeedc935b508 Mon Sep 17 00:00:00 2001 From: blaginin Date: Tue, 26 Nov 2024 19:09:32 +0000 Subject: [PATCH 6/7] Add notes --- derive/src/lib.rs | 2 ++ src/parser/mod.rs | 3 +++ 2 files changed, 5 insertions(+) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index ffa56a533..067c3778a 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -77,6 +77,8 @@ fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_ let expanded = quote! { // The generated impl. + // Note that it uses [`recursive::recursive`] to protect from stack overflow. + // See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info. impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause { #[cfg_attr(feature = "std", recursive::recursive)] fn visit( diff --git a/src/parser/mod.rs b/src/parser/mod.rs index b7f5cb866..6ed0fbfcd 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -73,6 +73,9 @@ mod recursion { /// Note: Uses an [`std::rc::Rc`] and [`std::cell::Cell`] in order to satisfy the Rust /// borrow checker so the automatic [`DepthGuard`] decrement a /// reference to the counter. + /// + /// Note: when "std" feature is enabled, this crate uses additional stack overflow protection + /// for some of its recursive methods. See [`recursive::recursive`] for more information. pub(crate) struct RecursionCounter { remaining_depth: Rc>, } From 913a291b2d3a8a6ea7d24c5f4c350a82a3cfa242 Mon Sep 17 00:00:00 2001 From: blaginin Date: Tue, 26 Nov 2024 19:15:49 +0000 Subject: [PATCH 7/7] Add a note on `with_recursion_limit` --- src/parser/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 6ed0fbfcd..bc2e68ef6 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -329,6 +329,9 @@ impl<'a> Parser<'a> { /// # Ok(()) /// # } /// ``` + /// + /// Note: when "std" feature is enabled, this crate uses additional stack overflow protection + // for some of its recursive methods. See [`recursive::recursive`] for more information. pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self { self.recursion_counter = RecursionCounter::new(recursion_limit); self