Skip to content

Commit

Permalink
[DDL] Rename table and Copy table to use getTableNameFromNode (#622)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Dec 19, 2024
1 parent 2783cac commit 3ba635e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 16 deletions.
4 changes: 2 additions & 2 deletions lib/antlr/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ func processCopyTable(ctx *generated.CopyCreateTableContext) (Event, error) {
return nil, fmt.Errorf("expected exactly 2 table names, got %d", len(tableNames))
}

tableName, err := getTextFromSingleNodeBranch(tableNames[0])
tableName, err := getTableNameFromNode(tableNames[0])
if err != nil {
return nil, err
}

copiedFromTableName, err := getTextFromSingleNodeBranch(tableNames[1])
copiedFromTableName, err := getTableNameFromNode(tableNames[1])
if err != nil {
return nil, err
}
Expand Down
33 changes: 23 additions & 10 deletions lib/antlr/create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,10 @@ import (

func TestCreateTable(t *testing.T) {
{
// Create table LIKE
sameQueries := []string{
"CREATE TABLE table_name LIKE other_table;",
"create table table_name (like other_table);",
}

for _, query := range sameQueries {
events, err := Parse(query)
{
// Create table LIKE by specifying schema
events, err := Parse("CREATE TABLE db_name.table_name LIKE db_name.other_table;")
assert.NoError(t, err)
assert.Len(t, events, 1)

createTableEvent, isOk := events[0].(CopyTableEvent)
assert.True(t, isOk)
Expand All @@ -27,7 +21,26 @@ func TestCreateTable(t *testing.T) {
assert.Len(t, createTableEvent.GetColumns(), 0)
assert.Equal(t, "other_table", createTableEvent.GetCopyFromTableName())
}

{
// Create table LIKE
sameQueries := []string{
"CREATE TABLE table_name LIKE other_table;",
"create table table_name (like other_table);",
}

for _, query := range sameQueries {
events, err := Parse(query)
assert.NoError(t, err)
assert.Len(t, events, 1)

createTableEvent, isOk := events[0].(CopyTableEvent)
assert.True(t, isOk)

assert.Equal(t, "table_name", createTableEvent.GetTable())
assert.Len(t, createTableEvent.GetColumns(), 0)
assert.Equal(t, "other_table", createTableEvent.GetCopyFromTableName())
}
}
}
{
// Create table with column as CHARACTER SET and collation specified at the column level
Expand Down
7 changes: 6 additions & 1 deletion lib/antlr/rename_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ func processRenameTable(ctx *generated.RenameTableContext) ([]Event, error) {
case *generated.RenameTableClauseContext:
var allTableNames []string
for _, tableName := range castedChild.AllTableName() {
allTableNames = append(allTableNames, tableName.GetText())
parsedTableName, err := getTableNameFromNode(tableName)
if err != nil {
return nil, fmt.Errorf("failed to get table name: %w", err)
}

allTableNames = append(allTableNames, parsedTableName)
}

// Must be at least two table names
Expand Down
6 changes: 3 additions & 3 deletions lib/antlr/rename_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ func TestRenameTable(t *testing.T) {
}
{
// Another one table variant
events, err := Parse(`RENAME TABLE current_db.tbl_name TO other_db.tbl_name;`)
events, err := Parse(`RENAME TABLE current_db.tbl_name TO current_db.tbl_name;`)
assert.NoError(t, err)
assert.Len(t, events, 1)

renameTableEvent, isOk := events[0].(RenameTableEvent)
assert.True(t, isOk)

assert.Equal(t, "current_db.tbl_name", renameTableEvent.GetTable())
assert.Equal(t, "other_db.tbl_name", renameTableEvent.GetNewTableName())
assert.Equal(t, "tbl_name", renameTableEvent.GetTable())
assert.Equal(t, "tbl_name", renameTableEvent.GetNewTableName())
}
{
// Multiple tables
Expand Down
1 change: 1 addition & 0 deletions lib/antlr/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func getTextFromSingleNodeBranch(tree antlr.Tree) (string, error) {
return getTextFromSingleNodeBranch(tree.GetChild(0))
}

// TODO: Extend this function to return the schema (if present)
func getTableNameFromNode(ctx generated.ITableNameContext) (string, error) {
children := ctx.GetChildren()
if len(children) != 1 {
Expand Down

0 comments on commit 3ba635e

Please sign in to comment.