diff --git a/pkg/runtime/ast/ast_test.go b/pkg/runtime/ast/ast_test.go index 1827f299..aa70bf9b 100644 --- a/pkg/runtime/ast/ast_test.go +++ b/pkg/runtime/ast/ast_test.go @@ -45,11 +45,13 @@ func TestParse(t *testing.T) { "SELECT (2021 - birth_year) as AGE, count(1) as amount from student where uid between 1 and 10 group by (2021-birth_year)", "select * from student where uid = !0", } { - t.Run(sql, func(t *testing.T) { - _, stmt, err = Parse(sql) - assert.NoError(t, err) - t.Log("stmt:", stmt) - }) + t.Run( + sql, func(t *testing.T) { + _, stmt, err = Parse(sql) + assert.NoError(t, err) + t.Log("stmt:", stmt) + }, + ) } // 1. select statement @@ -95,17 +97,22 @@ func TestParse_UnionStmt(t *testing.T) { {"select 1 union select 2", "SELECT 1 UNION SELECT 2"}, {"select 1 union distinct select 2", "SELECT 1 UNION SELECT 2"}, {"select 1 union all select 2", "SELECT 1 UNION ALL SELECT 2"}, - {"select id,uid,name,nickname from student where uid in (?,?,?) union all select id,uid,name,nickname from tb_user where uid in (?,?,?)", "SELECT `id`,`uid`,`name`,`nickname` FROM `student` WHERE `uid` IN (?,?,?) UNION ALL SELECT `id`,`uid`,`name`,`nickname` FROM `tb_user` WHERE `uid` IN (?,?,?)"}, + { + "select id,uid,name,nickname from student where uid in (?,?,?) union all select id,uid,name,nickname from tb_user where uid in (?,?,?)", + "SELECT `id`,`uid`,`name`,`nickname` FROM `student` WHERE `uid` IN (?,?,?) UNION ALL SELECT `id`,`uid`,`name`,`nickname` FROM `tb_user` WHERE `uid` IN (?,?,?)", + }, } { - t.Run(next.input, func(t *testing.T) { - _, stmt, err := Parse(next.input) - assert.NoError(t, err, "should parse ok") - assert.IsType(t, (*UnionSelectStatement)(nil), stmt, "should be union statement") - - actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) - assert.NoError(t, err, "should restore ok") - assert.Equal(t, next.expect, actual) - }) + t.Run( + next.input, func(t *testing.T) { + _, stmt, err := Parse(next.input) + assert.NoError(t, err, "should parse ok") + assert.IsType(t, (*UnionSelectStatement)(nil), stmt, "should be union statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, next.expect, actual) + }, + ) } } @@ -117,16 +124,28 @@ func TestParse_SelectStmt(t *testing.T) { for _, next := range []tt{ {"select * from a left join b on a.k = b.k", "SELECT * FROM `a` LEFT JOIN `b` ON `a`.`k` = `b`.`k`"}, - {"select * from foo as a left join bar as b on a.k = b.k", "SELECT * FROM `foo` AS `a` LEFT JOIN `bar` AS `b` ON `a`.`k` = `b`.`k`"}, + { + "select * from foo as a left join bar as b on a.k = b.k", + "SELECT * FROM `foo` AS `a` LEFT JOIN `bar` AS `b` ON `a`.`k` = `b`.`k`", + }, {"select @@version", "SELECT @@`version`"}, {"select * from student for update", "SELECT * FROM `student` FOR UPDATE"}, {"select connection_id()", "SELECT CONNECTION_ID()"}, - {`SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user`, "SELECT CONCAT('\\'',`user`,'\\'@\\'',`host`,'\\'') FROM `mysql`.`user`"}, + { + `SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user`, + "SELECT CONCAT('\\'',`user`,'\\'@\\'',`host`,'\\'') FROM `mysql`.`user`", + }, {"select * from student where uid = abs(-11)", "SELECT * FROM `student` WHERE `uid` = ABS(-11)"}, {"select * from student where uid = 1 limit 3 offset ?", "SELECT * FROM `student` WHERE `uid` = 1 LIMIT ?,3"}, //{"select case count(*) when 0 then -3.14 else 2.17 end as xxx from student where uid in (-1,-2,-3)", "SELECT CASE COUNT(*) WHEN 0 THEN -3.14 ELSE 2.17 END AS `xxx` FROM `student` WHERE `uid` IN (-1,-2,-3)"}, - {"select * from tb_user a where (uid >= ? AND uid <= ?)", "SELECT * FROM `tb_user` AS `a` WHERE (`uid` >= ? AND `uid` <= ?)"}, - {"SELECT (2021 - birth_year) as AGE, count(1) as amount from student where uid between 1 and 10 group by (2021-birth_year)", "SELECT (2021-`birth_year`) AS `AGE`,COUNT(1) AS `amount` FROM `student` WHERE `uid` BETWEEN 1 AND 10 GROUP BY (2021-`birth_year`)"}, + { + "select * from tb_user a where (uid >= ? AND uid <= ?)", + "SELECT * FROM `tb_user` AS `a` WHERE (`uid` >= ? AND `uid` <= ?)", + }, + { + "SELECT (2021 - birth_year) as AGE, count(1) as amount from student where uid between 1 and 10 group by (2021-birth_year)", + "SELECT (2021-`birth_year`) AS `AGE`,COUNT(1) AS `amount` FROM `student` WHERE `uid` BETWEEN 1 AND 10 GROUP BY (2021-`birth_year`)", + }, {"select * from student where uid = !0", "SELECT * FROM `student` WHERE `uid` = !0"}, {"select convert(col using 'utf8')", "SELECT CONVERT(`col` USING utf8)"}, {"select convert('foo' using utf8mb4)", "SELECT CONVERT('foo' USING utf8mb4)"}, @@ -135,39 +154,95 @@ func TestParse_SelectStmt(t *testing.T) { {"select cast(3.14 as decimal(6,2))", "SELECT CAST(3.14 AS DECIMAL(6,2))"}, {"select cast(3.14 as char(6))", "SELECT CAST(3.14 AS CHAR(6))"}, //{"select cast('foo' as nchar(1))", "SELECT CAST('foo' AS NCHAR(1))"}, - {"select * from student force index(uk_uid) where uid in (1,2,3)", "SELECT * FROM `student` FORCE INDEX(`uk_uid`) WHERE `uid` IN (1,2,3)"}, - {"select * from student PARTITION (foo,bar) as foobar", "SELECT * FROM `student` PARTITION (`foo`,`bar`) AS `foobar`"}, - {"select IF(sum(gender),1,0)+1 as xy from tb_user where uid in (7777, 10099) or uid between 10000 and 10004", "SELECT IF(SUM(`gender`),1,0)+1 AS `xy` FROM `tb_user` WHERE `uid` IN (7777,10099) OR `uid` BETWEEN 10000 AND 10004"}, - {"select * from tb_user where uid is not null and uid = 10001", "SELECT * FROM `tb_user` WHERE `uid` IS NOT NULL AND `uid` = 10001"}, - {"select * from student where uid = case when 2>1 then ? end", "SELECT * FROM `student` WHERE `uid` = CASE WHEN 2 > 1 THEN ? END"}, - {"select * from student where uid = case when 2<>2 then ? end", "SELECT * FROM `student` WHERE `uid` = CASE WHEN 2 <> 2 THEN ? END"}, - {"select * from student where uid = case when 1=2 then 1 else ? end", "SELECT * FROM `student` WHERE `uid` = CASE WHEN 1 = 2 THEN 1 ELSE ? END"}, - {"select * from student where uid = case when 1=2 then 1 when 1=1 then 33 else 31 end", "SELECT * FROM `student` WHERE `uid` = CASE WHEN 1 = 2 THEN 1 WHEN 1 = 1 THEN 33 ELSE 31 END"}, - {"select * from student where uid = ABS(case when IF(1=2,true,false) then 1 else ? end)", "SELECT * FROM `student` WHERE `uid` = ABS(CASE WHEN IF(1 = 2,1,0) THEN 1 ELSE ? END)"}, // FIXME: use true/false instead of 1/0 - {"select * from student where uid = ABS(1-1+(case when IF(1=?,2,1)-1 then 1 else ? end))", "SELECT * FROM `student` WHERE `uid` = ABS(1-1+(CASE WHEN IF(1 = ?,2,1)-1 THEN 1 ELSE ? END))"}, - {"select * from student where uid = case (4%5) when 1 then 1 when 4 then ? else 0 end", "SELECT * FROM `student` WHERE `uid` = CASE (4%5) WHEN 1 THEN 1 WHEN 4 THEN ? ELSE 0 END"}, + { + "select * from student force index(uk_uid) where uid in (1,2,3)", + "SELECT * FROM `student` FORCE INDEX(`uk_uid`) WHERE `uid` IN (1,2,3)", + }, + { + "select * from student PARTITION (foo,bar) as foobar", + "SELECT * FROM `student` PARTITION (`foo`,`bar`) AS `foobar`", + }, + { + "select IF(sum(gender),1,0)+1 as xy from tb_user where uid in (7777, 10099) or uid between 10000 and 10004", + "SELECT IF(SUM(`gender`),1,0)+1 AS `xy` FROM `tb_user` WHERE `uid` IN (7777,10099) OR `uid` BETWEEN 10000 AND 10004", + }, + { + "select * from tb_user where uid is not null and uid = 10001", + "SELECT * FROM `tb_user` WHERE `uid` IS NOT NULL AND `uid` = 10001", + }, + { + "select * from student where uid = case when 2>1 then ? end", + "SELECT * FROM `student` WHERE `uid` = CASE WHEN 2 > 1 THEN ? END", + }, + { + "select * from student where uid = case when 2<>2 then ? end", + "SELECT * FROM `student` WHERE `uid` = CASE WHEN 2 <> 2 THEN ? END", + }, + { + "select * from student where uid = case when 1=2 then 1 else ? end", + "SELECT * FROM `student` WHERE `uid` = CASE WHEN 1 = 2 THEN 1 ELSE ? END", + }, + { + "select * from student where uid = case when 1=2 then 1 when 1=1 then 33 else 31 end", + "SELECT * FROM `student` WHERE `uid` = CASE WHEN 1 = 2 THEN 1 WHEN 1 = 1 THEN 33 ELSE 31 END", + }, + { + "select * from student where uid = ABS(case when IF(1=2,true,false) then 1 else ? end)", + "SELECT * FROM `student` WHERE `uid` = ABS(CASE WHEN IF(1 = 2,1,0) THEN 1 ELSE ? END)", + }, // FIXME: use true/false instead of 1/0 + { + "select * from student where uid = ABS(1-1+(case when IF(1=?,2,1)-1 then 1 else ? end))", + "SELECT * FROM `student` WHERE `uid` = ABS(1-1+(CASE WHEN IF(1 = ?,2,1)-1 THEN 1 ELSE ? END))", + }, + { + "select * from student where uid = case (4%5) when 1 then 1 when 4 then ? else 0 end", + "SELECT * FROM `student` WHERE `uid` = CASE (4%5) WHEN 1 THEN 1 WHEN 4 THEN ? ELSE 0 END", + }, //{"select birth_year,gender,count(*) as cnt from student where uid between 1 and 100 group by birth_year,gender having count(*)>5", "SELECT `birth_year`,`gender`,COUNT(*) AS `cnt` FROM `student` WHERE `uid` BETWEEN 1 AND 100 GROUP BY `birth_year`,`gender` HAVING COUNT(*) > 5"}, - {`select * from (select id,uid from student where uid in(1,?,?)) as aaa`, "SELECT * FROM (SELECT `id`,`uid` FROM `student` WHERE `uid` IN (1,?,?)) AS `aaa`"}, + { + `select * from (select id,uid from student where uid in(1,?,?)) as aaa`, + "SELECT * FROM (SELECT `id`,`uid` FROM `student` WHERE `uid` IN (1,?,?)) AS `aaa`", + }, //{"select count(*) from student where aaa.uid = 1", "SELECT COUNT(*) FROM `student` WHERE `aaa`.`uid` = 1"}, - {`select * from (select id,uid from student where uid in(1,2,3) union all select id,uid from student where uid in (?,?)) as aaa where aaa.uid=?`, "SELECT * FROM (SELECT `id`,`uid` FROM `student` WHERE `uid` IN (1,2,3) UNION ALL SELECT `id`,`uid` FROM `student` WHERE `uid` IN (?,?)) AS `aaa` WHERE `aaa`.`uid` = ?"}, + { + `select * from (select id,uid from student where uid in(1,2,3) union all select id,uid from student where uid in (?,?)) as aaa where aaa.uid=?`, + "SELECT * FROM (SELECT `id`,`uid` FROM `student` WHERE `uid` IN (1,2,3) UNION ALL SELECT `id`,`uid` FROM `student` WHERE `uid` IN (?,?)) AS `aaa` WHERE `aaa`.`uid` = ?", + }, {"select * from student where not uid = 1", "SELECT * FROM `student` WHERE not `uid` = 1"}, - {"select * from student where name not regexp '^Ch+'", "SELECT * FROM `student` WHERE `name` NOT REGEXP '^Ch+'"}, + { + "select * from student where name not regexp '^Ch+'", + "SELECT * FROM `student` WHERE `name` NOT REGEXP '^Ch+'", + }, {"select date_add(NOW(), interval 1 hour)", "SELECT DATE_ADD(NOW(),INTERVAL 1 HOUR)"}, - {"select distinct gender from student where uid in (1,2,3,4)", "SELECT DISTINCT `gender` FROM `student` WHERE `uid` IN (1,2,3,4)"}, - {"select distinct(gender) from student where uid in (1,2,3,4)", "SELECT DISTINCT (`gender`) FROM `student` WHERE `uid` IN (1,2,3,4)"}, - {"select * from foo inner join bar on foo.x = bar.y", "SELECT * FROM `foo` INNER JOIN `bar` ON `foo`.`x` = `bar`.`y`"}, - {"select * from foo left outer join bar on foo.x = bar.y", "SELECT * FROM `foo` LEFT JOIN `bar` ON `foo`.`x` = `bar`.`y`"}, + { + "select distinct gender from student where uid in (1,2,3,4)", + "SELECT DISTINCT `gender` FROM `student` WHERE `uid` IN (1,2,3,4)", + }, + { + "select distinct(gender) from student where uid in (1,2,3,4)", + "SELECT DISTINCT (`gender`) FROM `student` WHERE `uid` IN (1,2,3,4)", + }, + { + "select * from foo inner join bar on foo.x = bar.y", + "SELECT * FROM `foo` INNER JOIN `bar` ON `foo`.`x` = `bar`.`y`", + }, + { + "select * from foo left outer join bar on foo.x = bar.y", + "SELECT * FROM `foo` LEFT JOIN `bar` ON `foo`.`x` = `bar`.`y`", + }, {"select null as pkid", "SELECT NULL AS `pkid`"}, } { - t.Run(next.input, func(t *testing.T) { - _, stmt, err := Parse(next.input) - assert.NoError(t, err, "should parse ok") - assert.IsType(t, (*SelectStatement)(nil), stmt, "should be select statement") - - actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) - assert.NoError(t, err, "should restore ok") - assert.Equal(t, next.expect, actual) - }) + t.Run( + next.input, func(t *testing.T) { + _, stmt, err := Parse(next.input) + assert.NoError(t, err, "should parse ok") + assert.IsType(t, (*SelectStatement)(nil), stmt, "should be select statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, next.expect, actual) + }, + ) } } @@ -179,17 +254,22 @@ func TestParse_DeleteStmt(t *testing.T) { for _, it := range []tt{ {"delete from student where id = 1 limit 1", "DELETE FROM `student` WHERE `id` = 1 LIMIT 1"}, - {"delete low_priority quick ignore from student where id = 1", "DELETE LOW_PRIORITY QUICK IGNORE FROM `student` WHERE `id` = 1"}, + { + "delete low_priority quick ignore from student where id = 1", + "DELETE LOW_PRIORITY QUICK IGNORE FROM `student` WHERE `id` = 1", + }, } { - t.Run(it.input, func(t *testing.T) { - _, stmt, err := Parse(it.input) - assert.NoError(t, err) - assert.IsType(t, (*DeleteStatement)(nil), stmt, "should be delete statement") - - actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) - assert.NoError(t, err, "should restore ok") - assert.Equal(t, it.expect, actual) - }) + t.Run( + it.input, func(t *testing.T) { + _, stmt, err := Parse(it.input) + assert.NoError(t, err) + assert.IsType(t, (*DeleteStatement)(nil), stmt, "should be delete statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, it.expect, actual) + }, + ) } } @@ -202,15 +282,17 @@ func TestParse_DescribeStatement(t *testing.T) { for _, it := range []tt{ {"desc foobar", "DESC `foobar`"}, } { - t.Run(it.input, func(t *testing.T) { - _, stmt, err := Parse(it.input) - assert.NoError(t, err) - assert.IsType(t, (*DescribeStatement)(nil), stmt, "should be describe statement") - - actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) - assert.NoError(t, err, "should restore ok") - assert.Equal(t, it.expect, actual) - }) + t.Run( + it.input, func(t *testing.T) { + _, stmt, err := Parse(it.input) + assert.NoError(t, err) + assert.IsType(t, (*DescribeStatement)(nil), stmt, "should be describe statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, it.expect, actual) + }, + ) } } @@ -242,21 +324,26 @@ func TestParse_ShowStatement(t *testing.T) { {"show table status from foo", (*ShowTableStatus)(nil), "SHOW TABLE STATUS FROM `foo`"}, {"show table status in foo", (*ShowTableStatus)(nil), "SHOW TABLE STATUS FROM `foo`"}, {"show table status in foo like '%bar%'", (*ShowTableStatus)(nil), "SHOW TABLE STATUS FROM `foo` LIKE '%bar%'"}, - {"show table status from foo where name='bar'", (*ShowTableStatus)(nil), "SHOW TABLE STATUS FROM `foo` WHERE `name` = 'bar'"}, + { + "show table status from foo where name='bar'", (*ShowTableStatus)(nil), + "SHOW TABLE STATUS FROM `foo` WHERE `name` = 'bar'", + }, {"show nodes from arana", (*ShowNodes)(nil), "SHOW NODES FROM `arana`"}, {"show users from arana", (*ShowUsers)(nil), "SHOW USERS FROM `arana`"}, {"show sharding table from employees", (*ShowShardingTable)(nil), "employees"}, {"show create sequence arana", (*ShowCreateSequence)(nil), "SHOW CREATE SEQUENCE `arana`"}, } { - t.Run(it.input, func(t *testing.T) { - _, stmt, err := Parse(it.input) - assert.NoError(t, err) - assert.IsTypef(t, it.expectTyp, stmt, "should be %T", it.expectTyp) - - actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) - assert.NoError(t, err, "should restore ok") - assert.Equal(t, it.expect, actual) - }) + t.Run( + it.input, func(t *testing.T) { + _, stmt, err := Parse(it.input) + assert.NoError(t, err) + assert.IsTypef(t, it.expectTyp, stmt, "should be %T", it.expectTyp) + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, it.expect, actual) + }, + ) } } @@ -296,10 +383,12 @@ func TestParseMore(t *testing.T) { } for _, sql := range tbls { - t.Run(sql, func(t *testing.T) { - _, _, err := Parse(sql) - assert.NoError(t, err) - }) + t.Run( + sql, func(t *testing.T) { + _, _, err := Parse(sql) + assert.NoError(t, err) + }, + ) } } @@ -310,18 +399,26 @@ func TestParse_UpdateStmt(t *testing.T) { } for _, it := range []tt{ - {"update `student` set version=version+1,modified_at=NOW() where id = 1", "UPDATE `student` SET `version` = `version`+1, `modified_at` = NOW() WHERE `id` = 1"}, - {"update low_priority student set nickname = ? where id = 1 limit 1", "UPDATE LOW_PRIORITY `student` SET `nickname` = ? WHERE `id` = 1 LIMIT 1"}, + { + "update `student` set version=version+1,modified_at=NOW() where id = 1", + "UPDATE `student` SET `version` = `version`+1, `modified_at` = NOW() WHERE `id` = 1", + }, + { + "update low_priority student set nickname = ? where id = 1 limit 1", + "UPDATE LOW_PRIORITY `student` SET `nickname` = ? WHERE `id` = 1 LIMIT 1", + }, } { - t.Run(it.input, func(t *testing.T) { - _, stmt, err := Parse(it.input) - assert.NoError(t, err) - assert.IsTypef(t, (*UpdateStatement)(nil), stmt, "should be update statement") - - actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) - assert.NoError(t, err, "should restore ok") - assert.Equal(t, it.expect, actual) - }) + t.Run( + it.input, func(t *testing.T) { + _, stmt, err := Parse(it.input) + assert.NoError(t, err) + assert.IsTypef(t, (*UpdateStatement)(nil), stmt, "should be update statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, it.expect, actual) + }, + ) } } @@ -346,15 +443,17 @@ func TestParse_InsertStmt(t *testing.T) { "INSERT INTO `student`(`id`, `name`) VALUES (1, 'foo'),(2, 'bar') ON DUPLICATE KEY UPDATE `version` = `version`+1, `modified_at` = NOW()", }, } { - t.Run(it.input, func(t *testing.T) { - _, stmt, err := Parse(it.input) - assert.NoError(t, err) - assert.IsTypef(t, (*InsertStatement)(nil), stmt, "should be insert statement") - - actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) - assert.NoError(t, err, "should restore ok") - assert.Equal(t, it.expect, actual) - }) + t.Run( + it.input, func(t *testing.T) { + _, stmt, err := Parse(it.input) + assert.NoError(t, err) + assert.IsTypef(t, (*InsertStatement)(nil), stmt, "should be insert statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, it.expect, actual) + }, + ) } for _, it := range []tt{ @@ -375,15 +474,17 @@ func TestParse_InsertStmt(t *testing.T) { "INSERT INTO `student` SELECT `id`,`score` FROM `student_tmp` UNION SELECT `id`*10,`score`*10 FROM `student_tmp`", }, } { - t.Run(it.input, func(t *testing.T) { - _, stmt, err := Parse(it.input) - assert.NoError(t, err) - assert.IsTypef(t, (*InsertSelectStatement)(nil), stmt, "should be insert-select statement") - - actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) - assert.NoError(t, err, "should restore ok") - assert.Equal(t, it.expect, actual) - }) + t.Run( + it.input, func(t *testing.T) { + _, stmt, err := Parse(it.input) + assert.NoError(t, err) + assert.IsTypef(t, (*InsertSelectStatement)(nil), stmt, "should be insert-select statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, it.expect, actual) + }, + ) } } @@ -440,15 +541,17 @@ func TestParse_AlterTableStmt(t *testing.T) { "ALTER TABLE `student` RENAME COLUMN `name` TO `nickname`, RENAME COLUMN `nickname` TO `name`", }, } { - t.Run(it.input, func(t *testing.T) { - _, stmt, err := Parse(it.input) - assert.NoError(t, err) - assert.IsTypef(t, (*AlterTableStatement)(nil), stmt, "should be alter table statement") - - actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) - assert.NoError(t, err, "should restore ok") - assert.Equal(t, it.expect, actual) - }) + t.Run( + it.input, func(t *testing.T) { + _, stmt, err := Parse(it.input) + assert.NoError(t, err) + assert.IsTypef(t, (*AlterTableStatement)(nil), stmt, "should be alter table statement") + + actual, err := RestoreToString(RestoreDefault, stmt.(Restorer)) + assert.NoError(t, err, "should restore ok") + assert.Equal(t, it.expect, actual) + }, + ) } } @@ -459,7 +562,6 @@ func TestParse_DescStmt(t *testing.T) { desc := stmt.(*DescribeStatement) var sb strings.Builder _ = desc.Restore(RestoreDefault, &sb, nil) - t.Logf(sb.String()) assert.Equal(t, "DESC `student` `id`", sb.String()) } @@ -482,13 +584,15 @@ func TestRestore(t *testing.T) { for _, next := range []tt{ {"select @foobar", "SELECT @`foobar`"}, } { - t.Run(next.input, func(t *testing.T) { - _, stmt, err := Parse(next.input) - assert.NoError(t, err) - var sb strings.Builder - err = stmt.Restore(RestoreDefault, &sb, nil) - assert.NoError(t, err) - assert.Equal(t, next.output, sb.String()) - }) + t.Run( + next.input, func(t *testing.T) { + _, stmt, err := Parse(next.input) + assert.NoError(t, err) + var sb strings.Builder + err = stmt.Restore(RestoreDefault, &sb, nil) + assert.NoError(t, err) + assert.Equal(t, next.output, sb.String()) + }, + ) } }