diff --git a/src/cst/ProceduralLanguage.ts b/src/cst/ProceduralLanguage.ts index 76524988..91af35ea 100644 --- a/src/cst/ProceduralLanguage.ts +++ b/src/cst/ProceduralLanguage.ts @@ -9,6 +9,7 @@ import { FuncCall, CaseWhen, CaseElse, + Variable, } from "./Expr"; import { StringLiteral } from "./Literal"; import { Program } from "./Program"; @@ -92,7 +93,13 @@ export interface SetStmt extends BaseNode { type: "set_stmt"; setKw: Keyword<"SET">; assignments: ListExpr< - BinaryExpr>, "=", Expr> + BinaryExpr< + | Identifier + | Variable + | ParenExpr | ParenExpr>>, + "=", + Expr + > >; } diff --git a/src/parser.pegjs b/src/parser.pegjs index 448b44d3..0c627f46 100644 --- a/src/parser.pegjs +++ b/src/parser.pegjs @@ -4457,7 +4457,15 @@ set_stmt } set_assignment - = name:((ident / paren$list$ident) __) "=" value:(__ expr) { + = &mysql name:((ident / variable / paren$list$ident / paren$list$variable) __) "=" value:(__ expr) { + return loc({ + type: "binary_expr", + left: read(name), + operator: "=", + right: read(value), + }) + } + / !mysql name:((ident / paren$list$ident) __) "=" value:(__ expr) { return loc({ type: "binary_expr", left: read(name), @@ -6852,6 +6860,7 @@ paren$list$string_literal = . paren$list$table_func_call = . paren$list$table_option_postgresql = . paren$list$tablesample_arg = . +paren$list$variable = . paren$list$view_column_definition = . paren$pivot_for_in = . paren$postgresql_op = . diff --git a/test/proc/set.test.ts b/test/proc/set.test.ts index 21603c13..f7231bf2 100644 --- a/test/proc/set.test.ts +++ b/test/proc/set.test.ts @@ -19,6 +19,20 @@ describe("SET", () => { }); }); + dialect(["mysql", "mariadb"], () => { + it("supports SET statement", () => { + testWc("SET @x = 10"); + }); + + it("supports multiple assignments", () => { + testWc("SET @x = 1, @y = 'foo', @z = false"); + }); + + it("supports scalar subquery", () => { + testWc("SET @sum_age = (SELECT SUM(age) FROM user)"); + }); + }); + dialect("sqlite", () => { it("does not support SET statement", () => { expect(() => parse("SET x = 1")).toThrowError();