diff --git a/src/ast/mod.rs b/src/ast/mod.rs index d4808d411..e8f15a5a9 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1779,12 +1779,13 @@ impl fmt::Display for Statement { .iter() .map(|key_value| { format!( - "{}{}", + "{}{} = {}", if key_value.local { "LOCAL " } else { "" }, + key_value.key, key_value .value .iter() - .map(|value| format!("{} = {}", key_value.key, value)) + .map(|value| format!("{}", value)) .collect::>() .join(", ") ) diff --git a/src/parser.rs b/src/parser.rs index 9df2b7f65..0a179a23b 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -3684,11 +3684,16 @@ impl<'a> Parser<'a> { } let mut key_values: Vec = vec![]; - loop { + + if dialect_of!(self is PostgreSqlDialect | RedshiftSqlDialect) { let variable = self.parse_identifier()?; let mut values = vec![]; - if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { + if !self.consume_token(&Token::Eq) && !self.parse_keyword(Keyword::TO) { + return self.expected("equals sign or TO", self.peek_token()); + } + + loop { let value = if let Ok(expr) = self.parse_expr() { expr } else { @@ -3697,22 +3702,50 @@ impl<'a> Parser<'a> { values.push(value); - key_values.push(SetVariableKeyValue { - key: variable, - value: values, - local: modifier == Some(Keyword::LOCAL), - hivevar: false, - }); - - if self.consume_token(&Token::Comma) { - continue; + if !self.consume_token(&Token::Comma) { + break; } + } - return Ok(Statement::SetVariable { key_values }); - } else { + key_values.push(SetVariableKeyValue { + key: variable, + value: values, + local: modifier == Some(Keyword::LOCAL), + hivevar: false, + }); + + return Ok(Statement::SetVariable { key_values }); + } + + loop { + let variable = self.parse_identifier()?; + let mut values = vec![]; + + if !self.consume_token(&Token::Eq) && !self.parse_keyword(Keyword::TO) { return self.expected("equals sign or TO", self.peek_token()); } + + let value = if let Ok(expr) = self.parse_expr() { + expr + } else { + self.expected("variable value", self.peek_token())? + }; + + values.push(value); + + key_values.push(SetVariableKeyValue { + key: variable, + value: values, + local: modifier == Some(Keyword::LOCAL), + hivevar: false, + }); + + if !self.consume_token(&Token::Comma) { + break; + } } + + Ok(Statement::SetVariable { key_values }) } pub fn parse_show(&mut self) -> Result { diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index c10dfef76..ac8d283f9 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -790,6 +790,29 @@ fn parse_set() { } ); + let stmt = pg().verified_stmt("SET a = b, c"); + assert_eq!( + stmt, + Statement::SetVariable { + key_values: [SetVariableKeyValue { + key: "a".into(), + value: vec![ + Expr::Identifier(Ident { + value: "b".into(), + quote_style: None + }), + Expr::Identifier(Ident { + value: "c".into(), + quote_style: None + }), + ], + local: false, + hivevar: false, + }] + .to_vec() + } + ); + let stmt = pg_and_generic().verified_stmt("SET a = 'b'"); assert_eq!( stmt,