diff --git a/sqlparser_bench/benches/sqlparser_bench.rs b/sqlparser_bench/benches/sqlparser_bench.rs index 8654a313f..9517cc1b3 100644 --- a/sqlparser_bench/benches/sqlparser_bench.rs +++ b/sqlparser_bench/benches/sqlparser_bench.rs @@ -273,6 +273,32 @@ fn parse_table_factor_paren_chain(c: &mut Criterion) { group.finish(); } +/// Benchmark parsing pathological nested function-call arguments that +/// previously caused 2^N work in `parse_function_args`. Each positional +/// argument was speculatively parsed as a named-argument name (a full +/// expression), and when no `=>`/`:=` operator followed, rewound and +/// re-parsed as a positional argument -- doubling work at every nesting +/// level. PostgreSQL is used because it enables expression-named arguments. +fn parse_function_call_arg_chain(c: &mut Criterion) { + let mut group = c.benchmark_group("parse_function_call_arg_chain"); + let dialect = PostgreSqlDialect {}; + + for &n in &[10usize, 20, 30] { + let sql = String::from("SELECT ") + &"replace(".repeat(n) + "x" + &",'a','b')".repeat(n); + + group.bench_function(format!("chain_{n}"), |b| { + b.iter(|| { + let _ = Parser::new(&dialect) + .with_recursion_limit(256) + .try_with_sql(std::hint::black_box(&sql)) + .and_then(|mut p| p.parse_statements()); + }); + }); + } + + group.finish(); +} + criterion_group!( benches, basic_queries, @@ -282,6 +308,7 @@ criterion_group!( parse_compound_keyword_chain, parse_prefix_keyword_call_chain, parse_prefix_case_chain, - parse_table_factor_paren_chain + parse_table_factor_paren_chain, + parse_function_call_arg_chain ); criterion_main!(benches); diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 745478300..4b2f80ac2 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -18555,33 +18555,58 @@ impl<'a> Parser<'a> { /// Parse a single function argument, handling named and unnamed variants. pub fn parse_function_args(&mut self) -> Result { - let arg = if self.dialect.supports_named_fn_args_with_expr_name() { - self.maybe_parse(|p| { - let name = p.parse_expr()?; - let operator = p.parse_function_named_arg_operator()?; - let arg = p.parse_wildcard_expr()?.into(); - Ok(FunctionArg::ExprNamed { - name, - arg, - operator, - }) - })? - } else { - self.maybe_parse(|p| { - let name = p.parse_identifier()?; - let operator = p.parse_function_named_arg_operator()?; - let arg = p.parse_wildcard_expr()?.into(); - Ok(FunctionArg::Named { - name, - arg, - operator, - }) - })? - }; + // For dialects where a named-argument name may be an arbitrary + // expression (e.g. MSSQL `JSON_OBJECT`, PostgreSQL), parse the leading + // expression once and then check for a named-argument operator: when one + // follows it is a named argument, otherwise the same expression is the + // positional argument. Parsing once avoids speculatively parsing the + // whole expression and re-parsing it on the positional path, which was + // O(2^n) on deeply nested function-call arguments. + if self.dialect.supports_named_fn_args_with_expr_name() { + let expr = self.parse_wildcard_expr()?; + // A wildcard (`*`, `t.*`) can never be a named-argument name. + if !matches!(expr, Expr::Wildcard(_)) { + if let Some(operator) = + self.maybe_parse(|p| p.parse_function_named_arg_operator())? + { + let arg = self.parse_wildcard_expr()?.into(); + return Ok(FunctionArg::ExprNamed { + name: expr, + arg, + operator, + }); + } + } + return self.parse_unnamed_function_arg(expr); + } + + // Dialects where the name must be a bare identifier: the speculative + // parse only consumes a single token, so re-parsing on the positional + // path is cheap. + let arg = self.maybe_parse(|p| { + let name = p.parse_identifier()?; + let operator = p.parse_function_named_arg_operator()?; + let arg = p.parse_wildcard_expr()?.into(); + Ok(FunctionArg::Named { + name, + arg, + operator, + }) + })?; if let Some(arg) = arg { return Ok(arg); } let wildcard_expr = self.parse_wildcard_expr()?; + self.parse_unnamed_function_arg(wildcard_expr) + } + + /// Build an unnamed [`FunctionArg`] from an already-parsed argument + /// expression, applying wildcard options (`* EXCLUDE (...)`) and aliasing + /// (`expr AS name`) where the dialect supports them. + fn parse_unnamed_function_arg( + &mut self, + wildcard_expr: Expr, + ) -> Result { let arg_expr: FunctionArgExpr = match wildcard_expr { Expr::Wildcard(ref token) if self.dialect.supports_select_wildcard_exclude() => { // Support `* EXCLUDE(col1, col2, ...)` inside function calls (e.g. Snowflake's diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index b48b1b5a0..d23324d84 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -15590,15 +15590,15 @@ fn test_reserved_keywords_for_identifiers() { ); // Dialects with expression-named function arguments parse the argument - // expression twice, so the second attempt reports the memoized failure - // at the start of the expression + // expression once and report the same error: INTERVAL begins an interval + // expression, which then fails on the closing paren. let dialects = all_dialects_where(|d| { d.is_reserved_for_identifier(Keyword::INTERVAL) && d.supports_named_fn_args_with_expr_name() }); assert_eq!( dialects.parse_sql_statements(sql), Err(ParserError::ParserError( - "Expected: an expression, found: interval".to_string() + "Expected: an expression, found: )".to_string() )) ); @@ -19486,3 +19486,31 @@ fn parse_table_factor_paren_chain_no_exponential_blowup() { rx.recv_timeout(Duration::from_secs(5)) .expect("parser should reject this quickly, not loop exponentially"); } + +/// Regression test for the 2^N parse-time blowup in `parse_function_args` on +/// nested function-call arguments like `replace(replace(replace(...x...)))`. +/// On dialects with expression-named arguments, each positional argument was +/// speculatively parsed as a named-argument name (a full expression) and, +/// when no `=>`/`:=` operator followed, rewound and re-parsed as a positional +/// argument -- doubling work at every nesting level. Post-fix the leading +/// expression is parsed once. +#[test] +fn parse_function_call_arg_chain_no_exponential_blowup() { + use std::sync::mpsc; + use std::thread; + use std::time::Duration; + + let sql = String::from("SELECT ") + &"replace(".repeat(30) + "x" + &",'a','b')".repeat(30); + + let (tx, rx) = mpsc::channel(); + thread::spawn(move || { + let _ = Parser::new(&PostgreSqlDialect {}) + .with_recursion_limit(256) + .try_with_sql(&sql) + .and_then(|mut p| p.parse_statements()); + let _ = tx.send(()); + }); + + rx.recv_timeout(Duration::from_secs(5)) + .expect("parser should handle this quickly, not loop exponentially"); +}