From 260acc3da393916f0f0b3c919d76a1a5402fbde9 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Mon, 14 Oct 2024 14:21:25 -0700 Subject: [PATCH 1/2] Bug fix: allow quoted column names in COPY statement --- server/node/copy_from.go | 4 +++- testing/bats/dataloading.bats | 21 +++++++++++++++++++ .../tab-load-with-quoted-column-names.sql | 16 ++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 testing/bats/dataloading/tab-load-with-quoted-column-names.sql diff --git a/server/node/copy_from.go b/server/node/copy_from.go index adce19af3d..d34488b1e7 100644 --- a/server/node/copy_from.go +++ b/server/node/copy_from.go @@ -18,6 +18,7 @@ import ( "bufio" "fmt" "os" + "strings" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/go-mysql-server/sql" @@ -109,7 +110,8 @@ func (cf *CopyFrom) Validate(ctx *sql.Context) error { for i, col := range table.Schema() { name := cf.Columns[i] - if name.String() != col.Name { + nameString := strings.Trim(name.String(), `"`) + if nameString != col.Name { return fmt.Errorf("invalid column name list for table %s: %v", table.Name(), cf.Columns) } } diff --git a/testing/bats/dataloading.bats b/testing/bats/dataloading.bats index 54430fbcf7..c8019fa8b6 100644 --- a/testing/bats/dataloading.bats +++ b/testing/bats/dataloading.bats @@ -89,6 +89,27 @@ teardown() { [[ "$output" =~ "3 | 03 | 97302 | Guyane" ]] || false } +# Tests that we can load tabular data dump files that contain quoted column names +@test 'dataloading: tabular import, with quoted column names' { + # Import the data dump and assert the expected output + run query_server -f $BATS_TEST_DIRNAME/dataloading/tab-load-with-quoted-column-names.sql + [ "$status" -eq 0 ] + [[ "$output" =~ "COPY 3" ]] || false + [[ ! "$output" =~ "ERROR" ]] || false + + # Check the row count of imported tables + run query_server -c "SELECT count(*) from Regions;" + [ "$status" -eq 0 ] + [[ "$output" =~ "3" ]] || false + + # Check the inserted rows + run query_server -c "SELECT * from Regions;" + [ "$status" -eq 0 ] + [[ "$output" =~ "1 | 01 | 97105 | Guadeloupe" ]] || false + [[ "$output" =~ "2 | 02 | 97209 | Martinique" ]] || false + [[ "$output" =~ "3 | 03 | 97302 | Guyane" ]] || false +} + # Tests that we can load tabular data dump files that do not explicitly manage the session's transaction. @test 'dataloading: tabular import, no explicit tx management' { # Import the data dump and assert the expected output diff --git a/testing/bats/dataloading/tab-load-with-quoted-column-names.sql b/testing/bats/dataloading/tab-load-with-quoted-column-names.sql new file mode 100644 index 0000000000..1fa1e8ffc0 --- /dev/null +++ b/testing/bats/dataloading/tab-load-with-quoted-column-names.sql @@ -0,0 +1,16 @@ +BEGIN; + +CREATE TABLE Regions ( + "Id" SERIAL UNIQUE NOT NULL, + "Code" VARCHAR(4) UNIQUE NOT NULL, + "Capital" VARCHAR(10) NOT NULL, + "Name" VARCHAR(255) UNIQUE NOT NULL +); + +COPY regions ("Id", "Code", "Capital", "Name") FROM stdin; +1 01 97105 Guadeloupe +2 02 97209 Martinique +3 03 97302 Guyane +\. + +COMMIT; From 49f98a0e6267cc6f3443faa3819d6c0a239ff276 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Mon, 14 Oct 2024 14:49:02 -0700 Subject: [PATCH 2/2] Bug fix: allow quoted column names in nextval and setval functions --- server/functions/nextval.go | 26 ++++----------------- server/functions/setval.go | 36 ++++++++++++++++++++++++++--- testing/go/sequences_test.go | 45 ++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/server/functions/nextval.go b/server/functions/nextval.go index 7085be7925..9d165bd670 100644 --- a/server/functions/nextval.go +++ b/server/functions/nextval.go @@ -15,9 +15,6 @@ package functions import ( - "fmt" - "strings" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/core" @@ -38,26 +35,11 @@ var nextval_text = framework.Function1{ IsNonDeterministic: true, Strict: true, Callable: func(ctx *sql.Context, _ [2]pgtypes.DoltgresType, val any) (any, error) { - var schema, sequence string - var err error - pathElems := strings.Split(val.(string), ".") - switch len(pathElems) { - case 1: - schema, err = core.GetCurrentSchema(ctx) - if err != nil { - return nil, err - } - sequence = pathElems[0] - case 2: - schema = pathElems[0] - sequence = pathElems[1] - case 3: - // database is not used atm - schema = pathElems[1] - sequence = pathElems[2] - default: - return nil, fmt.Errorf(`cannot find sequence "%s" to get its nextval`, val.(string)) + schema, sequence, err := parseRelationName(ctx, val.(string)) + if err != nil { + return nil, err } + collection, err := core.GetCollectionFromContext(ctx) if err != nil { return nil, err diff --git a/server/functions/setval.go b/server/functions/setval.go index f195b979ec..d1a739543b 100644 --- a/server/functions/setval.go +++ b/server/functions/setval.go @@ -15,6 +15,9 @@ package functions import ( + "fmt" + "strings" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/core" @@ -54,11 +57,38 @@ var setval_text_int64_boolean = framework.Function3{ return nil, err } // TODO: this should take a regclass as the parameter to determine the schema - schema, err := core.GetCurrentSchema(ctx) + schema, relation, err := parseRelationName(ctx, val1.(string)) if err != nil { return nil, err } - - return val2.(int64), collection.SetVal(schema, val1.(string), val2.(int64), val3.(bool)) + return val2.(int64), collection.SetVal(schema, relation, val2.(int64), val3.(bool)) }, } + +// parseRelationName parses the schema and relation name from a relation name string, including trimming any +// identifier quotes used in the name. For example, passing in 'public."MyTable"' would return 'public' and 'MyTable'. +func parseRelationName(ctx *sql.Context, name string) (schema string, relation string, err error) { + pathElems := strings.Split(name, ".") + switch len(pathElems) { + case 1: + schema, err = core.GetCurrentSchema(ctx) + if err != nil { + return "", "", err + } + relation = pathElems[0] + case 2: + schema = pathElems[0] + relation = pathElems[1] + case 3: + // database is not used atm + schema = pathElems[1] + relation = pathElems[2] + default: + return "", "", fmt.Errorf(`cannot parse relation: %s`, name) + } + + // Trim any quotes from the relation name + relation = strings.Trim(relation, `"`) + + return schema, relation, nil +} diff --git a/testing/go/sequences_test.go b/testing/go/sequences_test.go index cbe0bc4d43..618db4e018 100644 --- a/testing/go/sequences_test.go +++ b/testing/go/sequences_test.go @@ -403,6 +403,38 @@ func TestSequences(t *testing.T) { }, }, }, + { + Name: "nextval() with double-quoted identifiers", + SetUpScript: []string{ + "CREATE SEQUENCE test_sequence;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT nextval('test_sequence');", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "SELECT nextval('public.test_sequence');", + Expected: []sql.Row{ + {2}, + }, + }, + { + Query: `SELECT nextval('"test_sequence"');`, + Expected: []sql.Row{ + {3}, + }, + }, + { + Query: `SELECT nextval('public."test_sequence"');`, + Expected: []sql.Row{ + {4}, + }, + }, + }, + }, { Name: "nextval() in filter", Skip: true, // GMS seems to call nextval once and cache the value, which is incorrect here @@ -539,6 +571,19 @@ func TestSequences(t *testing.T) { Query: "SELECT nextval('test4');", Expected: []sql.Row{{7}}, }, + { + Query: "CREATE SEQUENCE test5;", + Expected: []sql.Row{}, + }, + { + // test with a double-quoted identifier + Query: `SELECT setval('public."test5"', 100, true);`, + Expected: []sql.Row{{100}}, + }, + { + Query: "SELECT nextval('test5');", + Expected: []sql.Row{{101}}, + }, }, }, {